DELE ST1504 CA2 Part A: Generative Adversarial Network


Name: Lee Hong Yi & Yadanar Aung
Admin No: 2223010 & 2214621
Class: DAAA/FT/2B/07

Objective:
Develop a Generative Adversarial Network (GAN) model for image generation, utilizing the CIFAR10 dataset. The model aims to generate 1000 high-quality, small color images in 10 distinct classes, showcasing its ability to learn and replicate complex visual patterns.

Background:
GANs are a revolutionary class of artificial neural networks used in unsupervised machine learning tasks. They consist of two parts: a Generator, which creates images, and a Discriminator, which evaluates them. The objective is to train a GAN that excels in producing diverse, realistic images that closely mimic the characteristics of the CIFAR10 dataset.

Key Features:
Implement and evaluate different GAN architectures to determine the most effective model for the CIFAR10 specific image generation task, which should generate images that not only are visually appealing and realistic but also display a wide range of creativity within the constraints of the 10 classes in the dataset.

Output Specification:
The model will produce images that are evaluated based on their similarity to the real images in the CIFAR10 dataset and their diversity across the dataset's classes. The performance of the GAN will be a crucial indicator of its effectiveness in learning and replicating complex patterns from a given dataset.


Performing initial set-up
In [ ]:
import gc
import os
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt
from tensorflow.keras import Model
from skimage.transform import resize
from tensorflow.keras.callbacks import Callback
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.losses import BinaryCrossentropy, Hinge
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.regularizers import l1, l2, l1_l2
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from tensorflow.keras.layers import Dense, Reshape, UpSampling2D, Conv2D, BatchNormalization, LeakyReLU, ZeroPadding2D, Dropout, Flatten, Input, Activation, GlobalMaxPooling2D, Conv2DTranspose, PReLU, Embedding, Concatenate
from tensorflow.keras.metrics import Mean
#from tensorflow_addons.layers import SpectralNormalization

import GAN_function as gnnf
In [ ]:
from warnings import simplefilter
simplefilter(action='ignore', category=UserWarning)
simplefilter(action='ignore', category=FutureWarning)
In [ ]:
# Fix random seed for reproducibility
seed = 1
np.random.seed(seed)
tf.random.set_seed(seed)
tf.keras.utils.set_random_seed(0)
In [ ]:
# Check GPU is available
gpus = tf.config.experimental.list_physical_devices('GPU')

# Memory control: Prevent tensorflow from allocating totality of GPU memory
for gpu in gpus:
    try:
        print(tf.config.experimental.get_device_details(gpu))
    except:
        pass
    tf.config.experimental.set_memory_growth(gpu, True)
print(f"There are {len(gpus)} GPU(s) present.")
{'device_name': 'NVIDIA GeForce RTX 3060', 'compute_capability': (8, 6)}
There are 1 GPU(s) present.

Background Research

CIFAR10 Dataset:

  • The CIFAR10 (Canadian Institute for Advanced Research) dataset consists of 60,000 colour images in 10 classes.
  • There are 6,000 images per class.

Images:

  • The images are split into 50,000 train images and 10,000 test images.
  • The images are of size 32x32.

Classes:

  • Total of 10 distinct classes:
    1. Airplane
    2. Automoblie
    3. Bird
    4. Cat
    5. Deer
    6. Dog
    7. Frog
    8. Horse
    9. Ship
    10. Truck
  • Classes are mutually exclusive.
    • There is no overlap between automobiles and trucks, neither includes pickup trucks.
    • "Automobile" includes sedans, SUVs, etc.
    • "Truck" includes only big trucks.

Batches:

  • The dataset is divided into 5 train batches & 1 test batch, each with 10,000 images.
  • Train batches contain 50,000 images in total from each class in random order
    • Some batches contain more images from one class than another
  • Test batch contains 10,000 randomly-selected images from each class

Source: https://www.cs.toronto.edu/~kriz/cifar.html


Load CIFAR10 Dataset
It returns two tuples: (x_train, y_train), (x_test, y_test). The first element of each tuple is an array of images. The second element is an array of corresponding labels.

x_train: uint8 NumPy array of grayscale image data with shapes (50000, 32, 32, 3), containing the training data. Pixel values range from 0 to 255.

y_train: uint8 NumPy array of labels (integers in range 0-9) with shape (50000, 1) for the training data.

x_test: uint8 NumPy array of grayscale image data with shapes (10000, 32, 32, 3), containing the test data. Pixel values range from 0 to 255.

y_test: uint8 NumPy array of labels (integers in range 0-9) with shape (10000, 1) for the test data.

Source: https://keras.io/api/datasets/cifar10/

In [ ]:
# Load CIFAR10 Dataset
cifar10 = tf.keras.datasets.cifar10.load_data()

Split into Train Dataset

In [ ]:
# Load CIFAR-10 dataset
(X_train, y_train), (X_test, y_test) = cifar10

# Combine Train and Test datasets
# X_train = np.concatenate((X_train, X_test), axis=0)
# y_train = np.concatenate((y_train, y_test), axis=0)

# For EDA Purposes
eda_data = X_train

# Print the shapes of the combined datasets
print(f"Shape of combined X (features): {X_train.shape}")
print(f"Shape of combined y (labels): {y_train.shape}")
Shape of combined X (features): (50000, 32, 32, 3)
Shape of combined y (labels): (50000, 1)

Define Class Labels

In [ ]:
# Map integer class labels to their corresponding class names
class_names = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

Exploratory Data Analysis (EDA)

Image Pixel Distribution

In [ ]:
# Obtain Pixel Statistics
min = np.min(eda_data, axis=(0, 1, 2))
max = np.max(eda_data, axis=(0, 1, 2))
mean = np.mean(eda_data, axis=(0, 1, 2))
std = np.std(eda_data, axis=(0, 1, 2))

# Print Statistics
print("\nPixel Statistics for the original train dataset:")
print(f"Minimum pixel value: {min}")
print(f"Maximum pixel value: {max}")
print(f"Mean pixel value: {mean}")
print(f"Standard deviation of pixel values: {std}")
Pixel Statistics for the original train dataset:
Minimum pixel value: [0 0 0]
Maximum pixel value: [255 255 255]
Mean pixel value: [125.30691805 122.95039414 113.86538318]
Standard deviation of pixel values: [62.99321928 62.08870764 66.70489964]

Class Distribution

In [ ]:
classes, counts = np.unique(y_train, return_counts=True)
class_count_dict = dict(zip(class_names, counts))
df = pd.DataFrame({'Count': class_count_dict})
df
Out[ ]:
Count
Airplane 5000
Automobile 5000
Bird 5000
Cat 5000
Deer 5000
Dog 5000
Frog 5000
Horse 5000
Ship 5000
Truck 5000
In [ ]:
# Visualise Distribution of Image Classes
gnnf.plot_counts(class_count_dict)
gnnf.plot_pie_chart(class_count_dict)

Insights:

  • Training dataset is balanced across the different classes.

Dataset Visualization

Display sample of images from each of the 10 classes.

In [ ]:
# Create figure & set size
fig, axes = plt.subplots(10, 10, figsize=(30, 30))

for i in range(len(class_names)):
    class_indices = np.where(y_train == i)[0]
    # Randomly select ten images
    random_indices = np.random.choice(class_indices, 10, replace=False)

    for j, image_index in enumerate(random_indices):
        axes[i, j].imshow(eda_data[image_index])
        # Load image
        axes[i, j].axis('off')
        axes[i, j].set_title(class_names[i])

plt.suptitle('10 Random Images Per Class', fontsize=25)
plt.tight_layout(rect=[0, 0, 1, 0.97])
plt.show()

Insights:

  • There is a wide variety presented in the images, from the color to the oritentation.
    • Images vary from a close-up shot of the face of the animals to the entire body.
    • Images vary in terms of capturing the front, back, side and top view of the objects.
    • Images within the same classes have different species, for example "Birds" have small and large birds like Ostriches.
  • A challenge identified is that there are certain texts in the images, as seen in the "Airplane" and "Cat" images.
    • This might cause the model to unintentionally generate images of texts.
  • The quality of the image is rather pixelated, hence certain features of the different classes may overlap.
  • The image size is 32x32, hence the impact of the small resolution may make it challenging for the model to learn the distinct features of each classes.

Image Averaging for Pixel Distribution

In [ ]:
gnnf.average_image(eda_data)

Insights:

  • From the above, which is a average of all the pictures in the CIFAR10 dataset, there are no outstanding features across the ten classes
  • The dataset also appears to be not uniform, as the image is blurry with no clear indication of a class
  • The dataset also seems to be homogenous, with most of the features being shared accross the images
  • These may pose challenges later on, as the models may not have enough features to identify the different classes
In [ ]:
gnnf.average_images_per_class(eda_data, y_train, class_names)

Insights:

  • The 4 distinct colors in the dataset are Blue, Green, Brown and Red.
  • When the images are averaged, we can still identify the outlines of an Automobile and Horse, implying that the shape of the classes will matter to the models.
  • Most of the images have a concentrated color in the center.

Feature Engineering

Normalization

Currently, the input pixel sizes are in the range [0, 255].

Hence, we are going to rescale the pixel values to the range [-1,1], so that the model can train more efficiently as pixel inputs with large integer values can slow down the training process.

In [ ]:
# Scale from [0,255] to [-1,1]
X_train_rescaled = (X_train / 127.5 - 1.).astype('float32')

# Obtain Pixel Statistics After Rescaling
min = np.min(X_train_rescaled, axis=(0, 1, 2))
max = np.max(X_train_rescaled, axis=(0, 1, 2))
mean = np.mean(X_train_rescaled, axis=(0, 1, 2))
std = np.std(X_train_rescaled, axis=(0, 1, 2))

# Print Statistics
print("Pixel Statistics for the original train dataset:")
print(f"Minimum pixel value: {min}")
print(f"Maximum pixel value: {max}")
print(f"Mean pixel value: {mean}")
print(f"Standard deviation of pixel values: {std}\n")
Pixel Statistics for the original train dataset:
Minimum pixel value: [-1. -1. -1.]
Maximum pixel value: [1. 1. 1.]
Mean pixel value: [-0.01720063 -0.03568322 -0.10693764]
Standard deviation of pixel values: [0.49406543 0.48696858 0.52317506]

Visualize Before vs After Rescaling

In [ ]:
# Plot first image from eda_data
plt.subplot(1, 2, 1)
plt.imshow(eda_data[0])
plt.title('Original Image')

# Plot first image from X_train_dataAug
plt.subplot(1, 2, 2)
plt.imshow(X_train_rescaled[0])
plt.title('Rescaled Image')

plt.show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).

Final Datasets for GAN Model Training

Finally, we are going to convert our finalized datasets to Tensor, from ndarray, for model training.

In [ ]:
# Convert frpm ndarray to Tensor
# X_train_rescaled_final = tf.convert_to_tensor((X_train_rescaled))
y_train = tf.convert_to_tensor((y_train))

X_train = tf.data.Dataset.from_tensor_slices((X_train_rescaled, y_train))
X_train = X_train.shuffle(1000).batch(32, drop_remainder=True)
X_train
Out[ ]:
<BatchDataset element_spec=(TensorSpec(shape=(32, 32, 32, 3), dtype=tf.float32, name=None), TensorSpec(shape=(32, 1), dtype=tf.uint8, name=None))>
In [ ]:
y_test = tf.convert_to_tensor((y_test))


X_test_rescaled = (X_test / 127.5 - 1.).astype('float32')
X_test = tf.data.Dataset.from_tensor_slices((X_test_rescaled, y_test))
X_test = X_test.shuffle(1000).batch(32, drop_remainder=True)

GAN Model Evaluation Methodology

As both the Generator and Discriminator are trained together to maintain an equilibirum in a zero-sum game, there is no objective loss function alone to evaluate the performance of the two models.

When measuring the performance of the GAN models, there are two properties to evaluate:

  1. Fidelity: Quality of generated images to measure how realistic the images are
  2. Diversity: Variety of generated images to measure if it covers the whole variety of the real distribution

A number of qualitative and quantitative techniques has been developed to evaluate the performance of the models, based on the quality and diversity of the generated images.

  1. Manual GAN Evaluation
  • It involves the manual visual inspection of the images generated by the Generator
  • Simplest method
  • Limitations include:
    • It is subjective and reviewer may be bias
    • The images from the dataset are of size 32x32, hence leading to difficult in differentiating the classses
    • Reviewing a large amount of generated images is time-consuming
  1. Qualitative GAN Evaluation
  • It involves human subjective evaluation or evaluation via comparison
  • Qualitative Techniques include:
    1. Nearest Neighbors: Detects overfitting, generated samples are shown next to their nearest neighbors in the training set
    2. Rapid Scene Categorization: Participants are to distinguish samples from real images in a short view time
    3. Rating & Preference Judgement: Participants to rank models in terms of the fidelity of their generated images
    4. Evaluating Mode Drop & Mode Collapse: Over datasets with known modes, modes are computed as by measuring distances of generated data to mode centers
    5. Investigating & Visualising the Internals of Networks: Explore & illustrate internal representation & dynamics of models, as well as visualizing learned features
  1. Quantitative GAN Evaluation
  • Refers to calculation of specific numerical scores used to summarize quality of generated images.
  • Quantitative Techniques include:
    1. Fréchet Inception Distance (FID): Wasserstein-2 distance between multi-variate Guassians fitted to data embedded into a feature space, to evaluate quality of generated images
    2. Kullback-Leibler (KL) Divergence: Measure of how one probability distribution diverges from a second, expected probability distribution

Sources:
https://machinelearningmastery.com/how-to-evaluate-generative-adversarial-networks/
https://towardsdatascience.com/on-the-evaluation-of-generative-adversarial-networks-b056ddcdfd3a

Selected Evaluation Metrics

To evaluate our GAN models, we will be using:

  1. Manual Visual Inspection
  2. FID Score
  3. KL Divergance

Fréchet Inception Distance (FID) Score

Fréchet Inception Distance (FID) evaluates the quality of generated images by calculating the distance between feature vectors calculated for real and generated images.

The FID score summarizes the similarity between the real and fake images in terms of statistics on computer vision features of the raw images, calculated by feature extractors. The most common feature extractor is the Inception-v3 classifier, which is pre-trained on ImageNet. By excluding the output layer, we extract the feature maps from the embeddings of the real and fake images. These embeddings are two multivariate normal distributions, which is compared using Wasserstein-2 distance.

\begin{aligned} FID &= \left\| \mu_r - \mu_g \right\|^2 + \text{Tr}\left(\Sigma_r + \Sigma_g - 2\left(\Sigma_r \Sigma_g\right)^{\frac{1}{2}}\right)\\ \text{where}\\ \mu_r &\text{ is the feature-wise mean of the real images.} \\ \mu_g &\text{ is the feature-wise mean of the generated images.} \\ \Sigma_r &\text{ is the covariance matrix of the real images.} \\ \Sigma_g &\text{ is the covariance matrix of the generated images.} \\ \text{Tr} &\text{ denotes the trace of a matrix, which is the sum of all the diagonal elements.} \\ \end{aligned}

It is to note that FID has its downsides. It uses a pre-trained Inception model, which may not capture all features, hence introducing biasness depending if the training data differs greatly from the domain of generated images. Moreover, it needs a large dataset to be accurate as it uses limited statistics of only mean and convariance.

Sources:
https://machinelearningmastery.com/how-to-implement-the-frechet-inception-distance-fid-from-scratch/
https://www.oreilly.com/library/view/generative-adversarial-networks/9781789136678/9bf2e543-8251-409e-a811-77e55d0dc021.xhtml
https://www.techtarget.com/searchenterpriseai/definition/Frechet-inception-distance-FID

In [ ]:
def calcFID(input_images, num_images=1000):
    inceptionModel = InceptionV3(
        include_top=False,
        weights="imagenet",
        pooling='avg',
    )

    def scale_images(images, new_shape):
        images_list = []
        for image in images:
            new_image = resize(image, new_shape, anti_aliasing=True)
            images_list.append(new_image)
        return np.array(images_list)

    def calculate_fid(model, images1, images2):
        act1 = model.predict(images1)
        act2 = model.predict(images2)
        mu1, sigma1 = act1.mean(axis=0), np.cov(act1, rowvar=False)
        mu2, sigma2 = act2.mean(axis=0), np.cov(act2, rowvar=False)
        ssdiff = np.sum((mu1 - mu2)**2.0)
        covmean = sqrtm(sigma1.dot(sigma2))
        if np.iscomplexobj(covmean):
            covmean = covmean.real
        fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
        return fid

    (real_images, _), (_, _) = tf.keras.datasets.cifar10.load_data()
    np.random.shuffle(real_images)

    real_images = real_images[:num_images]
    real_images = real_images.astype('float32')
    # print(real_images)
    real_images = (real_images / 127.5 - 1)
    real_images = scale_images(real_images, (299, 299))

    generated_images = input_images.astype('float32')
    generated_images = scale_images(generated_images, (299, 299))

    fid = calculate_fid(inceptionModel, real_images, generated_images)
    return fid

Kullback-Leibler (KL) Divergence

To help us further determine the quality of the model, we shall make use of the Kullback-Leibler Divergence metric (KL Divergence) to help us quantitatively evaluate the results of the model. KL Divergence is a statistical measure that quantifies how different one probability distribution is from another reference probability distribution. It is also known as relative entropy. KL Divergence is non-negative and asymmetric, meaning that the divergence of P from Q is not the same as the divergence of Q from P. It is often used in the field of machine learning to measure the difference between the predicted and true probability distributions of data, or to compare a model's distribution with the empirical distribution of the data. KL-divergence can be used within the generator loss function to encourage diverse outputs. This typically involves calculating the KL-divergence between the generated data distribution and a desired target distribution, like a uniform distribution over the data space.

The formula for KL Divergence is as follows: \begin{align*} &D_{KL}(P \parallel Q) = \sum_{i} P(i) \log \left( \frac{P(i)}{Q(i)} \right) \\ &\text{where}\\ &P \text{ is the true distribution} \\ &Q \text{ is the distribution to compare against} \\ &\sum_{i} \text{ is taken over all possible events} \end{align*}

This equation essentially sums up the product of the probabilities from the true distribution ( P(i) ) and the logarithm of the ratio of probabilities from the true distribution to the comparison distribution ( Q(i) ). It is a measure of the information gained about ( P ) when one uses ( Q ) as the approximation. For our use case, we would want to see a low KL Divergence score, as it means that the distribution of the generated data is very close to the distribution of the real data. This is the goal of a well-functioning GAN – to generate data that is indistinguishable from real data. It also gives us a metric to directly compare models with, and also provides us with feedback on how the model is performing during training.


Initial Modelling
Understanding Generative Adversarial Networks (GAN)

GANs are an approach to generative modelling using deep learning methods, like CNNs.

GANs train a generative model by approaching the problem as a supervised learning problem with two sub-models:

  1. Generator: Learn to generate plausible new data from random noise
  2. Discriminator: Tries to distinguish real data from fake (generated) data, in other words Binary Classification

Training Process


These two models are trained together in a zero-sum game, until the Disciminator is fooled about half the time.

The Generator generates a batch of samples, and along with real samples from the dataset, are provided to the Discriminator to be classified as real (1) or fake (0).

While one model trains, the other model's weights remain constant, otherwise the Generator would be trying to hit a moving target & might never converge. The training proceeds in alternating periods, where each model take turns training for one or more epochs.

Loss Functions


Loss functions reflect the distance between the distribution of generated data and the distribution of real data.

Through backpropagation, the Discriminator's weights are updated from the discriminator loss to get better at discriminating, while the Generator's weights are updated from the generator loss based on the Discriminator classification, which is how well or not the generated samples fool the Discriminator.

Zero-sum game refers to when the Discriminator successfully identifies the real and fake samples, it is rewarded or no change is needed to the model parameters, whereas the the Generator is penalized with large updates to the model parameters, and vice versa.

Convergence


At a limit, the Discriminator cannot tell the difference between perfect replicas and the real images, hence predicts "unsure" (e.g. 50% for real and fake). If the GAN continues training with random feedback from the Discriminator, the model might collapse. For a GAN, convergence is often a fleeting, rather than stable, state.

Sources:
https://developers.google.com/machine-learning/gan/gan_structure
https://machinelearningmastery.com/how-to-develop-a-generative-adversarial-network-for-a-cifar-10-small-object-photographs-from-scratch/
https://machinelearningmastery.com/how-to-code-generative-adversarial-network-hacks/

To start tackling the task, we first design a baseline template class, which we can then build off of to create models which use different architectures. As we have already done our preprocessing, we can go straight onto modelling.

To start, we shall use the DCGAN architecture as a baseline, then move onto the cGAN, WGAN, and Hinge GAN architectures. From there, we can pick the best architecture for our task, and hypertune it further to determine our final GAN model to generate images from.

We shall also make use of the Binary Crossentropy loss function to help us here, over other loss functions such as Sparse Categorical Crossentropy, as it is designed for binary classification problems, in which the end goal is to predict one out of two possible outcomes. This fits our task, as in GANs, both the generator and discriminator are involved in a binary classification task, with the discriminator's task being to classify inputs as real or fake, and the generator's job being to generate outputs which are classified as real by the discriminator. The formula for Binary Cross-Entropy (BCE) is as follows: \begin{align*} &BCE = -\frac{1}{N} \sum_{i=1}^{N} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right] \\ &\text{where}\\ &N \text{ is the number of observations,} \\ &y_i \text{ is the actual label of the } i^{\text{th}} \text{ observation,} \\ &\hat{y}_i \text{ is the predicted probability that the } i^{\text{th}} \text{ observation is of the positive class.} \end{align*}

Furthermore, we shall use the Adam optimizer first. ADAptive Moment Estimation is a extension of Stochastic Gradient Descent. It combines ideas from two other optimization algorithms, namely Momentum and RMSProp. It is a good choice for us to start with, given it's "self-learning" properties, which means that it's learning rates are adaptive, relative to how frequently a parameter gets updated while training, making it suitable for problems with sparse gradients or with noisy data. It is also omputationally efficient with relatively low memory requirements, making it suitable for problems with large datasets or parameters.

We create the template class here, as well as implement a custom callback function so that we can track the model's performance easier.

Custom Callback to monitor GAN models' training progress and loss functions.

In [ ]:
class CustomCallback(Callback):
    def __init__(self, d_losses, g_losses, kl_div, model, filepath):
        super(CustomCallback, self).__init__()
        self.d_losses = d_losses
        self.g_losses = g_losses
        self.kl_div = kl_div
        self.model = model
        self.filepath = filepath

    def on_epoch_end(self, epoch, logs=None):
        gan_model = self.model
        generator = gan_model.generator
        d_loss = logs.get('d_loss')
        g_loss = logs.get('g_loss')
        kl_div = logs.get("kl_divergence")
        self.d_losses = np.array(list(self.d_losses))
        self.g_losses = np.array(list(self.g_losses))
        self.kl_div = np.array(list(self.kl_div))
        self.d_losses = np.append(self.d_losses, d_loss)
        self.g_losses = np.append(self.g_losses, g_loss)
        self.kl_div = np.append(self.kl_div, kl_div)
        generated_images, generated_labels = gan_model.generate_fake_samples(self.model, generator = generator)
        self.model.save_plot(generated_images, epoch, self.d_losses, self.g_losses, self.kl_div, self.filepath)
        self.model.save_weights(f"{self.filepath}weights/weights_{epoch}.h5")

GAN Template Class to perform inheritance

In [ ]:
class GAN_template(Model):
    def __init__(self, latent_dim):
        super().__init__()
        self.discriminator = self.define_discriminator()
        self.generator = self.define_generator(latent_dim)
        self.latent_dim = latent_dim
        self.d_loss_tracker = Mean(name="d_loss")
        self.g_loss_tracker = Mean(name="g_loss")
        self.kl_divergence_tracker = Mean(name = "kl_divergence")
        self.g_loss_list = []
        self.d_loss_list = []
        self.kl_div_list = []

    @staticmethod
    def save_plot(examples, epoch, d_losses, g_losses, kl_div, filepath):
        print(examples)
        fig = plt.figure(figsize=(15, 10))
        gs = fig.add_gridspec(4, 6, height_ratios=[1, 1, 1, 1.2], width_ratios=[1, 1, 1, 1, 1, 1], hspace=0.4, wspace=0.4)
        examples = (examples + 1) / 2.0
        for i in range(3 * 6):
            ax = fig.add_subplot(gs[i // 6, i % 6])
            ax.axis('off')
            ax.imshow(examples[i])
        ax_loss = fig.add_subplot(gs[3, 0:2])
        ax_loss.plot(d_losses, label="Discriminator Loss")
        ax_loss.set_title("Discriminator Loss")
        ax_g_loss = fig.add_subplot(gs[3, 2:4])
        ax_g_loss.plot(g_losses, label="Generator Loss")
        ax_g_loss.set_title("Generator Loss")
        ax_kl_div = fig.add_subplot(gs[3, 4:6])
        ax_kl_div.plot(kl_div, label="KL Divergence")
        ax_kl_div.set_title("KL Divergence")
        plt.suptitle(f"Epoch {epoch+1}", fontsize=18, y=0.92)
        plt.tight_layout()
        plt.savefig(f"{filepath}generated/generated_plot_e{epoch+1}.png", bbox_inches='tight')
        plt.close()

    def kl_divergence(self, real_data, generated_data):
        epsilon = 1e-10
        # real_data_flat = tf.reshape(real_data, [real_data.shape[0], -1])
        # generated_data_flat = tf.reshape(generated_data, [generated_data.shape[0], -1])
        real_data_hist = tf.histogram_fixed_width(real_data, [0, 1], nbins=256)
        generated_data_hist = tf.histogram_fixed_width(generated_data, [0, 1], nbins=256)
        real_data_prob = real_data_hist / tf.reduce_sum(real_data_hist)
        generated_data_prob = generated_data_hist / tf.reduce_sum(generated_data_hist)
        epsilon = 1e-10
        real_data_prob += epsilon
        generated_data_prob += epsilon
        kl_div = tf.reduce_sum(real_data_prob * tf.math.log(real_data_prob / generated_data_prob+epsilon))
        return kl_div

    @staticmethod
    def generate_fake_samples(self, generator, n_samples=18, latent_dim=100):
        x_input = np.random.randn(latent_dim * n_samples)
        x_input = x_input.reshape(n_samples, latent_dim)
        X = generator.predict(x_input, verbose=0)
        y = np.zeros((n_samples, 1))
        return X, y

    def define_discriminator(self, in_shape=(32,32,3)):
        pass

    def define_generator(self, latent_dim):
        pass

    def compile(self, d_optimizer, g_optimizer, loss_fn):
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

    def train_step(self, real_images_dataset):
        for real_images in real_images_dataset:
            if isinstance(real_images, tuple):
                real_images = real_images[0]
            batch_size = real_images[0].shape[0]
            random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
            generated_images = self.generator(random_latent_vectors)
            combined_images = tf.concat([generated_images, tf.cast(real_images, tf.float32)], axis=0)
            labels = tf.concat([tf.ones((batch_size, 1)), tf.zeros((batch_size, 1))], axis=0)
            labels += 0.05 * tf.random.uniform(tf.shape(labels))
            with tf.GradientTape() as tape:
                predictions = self.discriminator(combined_images)
                d_loss = self.loss_fn(labels, predictions)
            grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(
                zip(grads, self.discriminator.trainable_weights)
            )
            misleading_labels = tf.zeros((batch_size, 1))
            with tf.GradientTape() as tape:
                generated_images = self.generator(random_latent_vectors)
                predictions = self.discriminator(generated_images)
                g_loss = self.loss_fn(misleading_labels, predictions)
                kl_loss = self.kl_divergence(real_images, generated_images)
            grads = tape.gradient(g_loss, self.generator.trainable_weights)
            self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
    
            # Update metrics and return their value.
            self.d_loss_tracker.update_state(d_loss)
            self.g_loss_tracker.update_state(g_loss)
            self.kl_divergence_tracker.update_state(kl_loss)
    
            return {
                "d_loss": self.d_loss_tracker.result(),
                "g_loss": self.g_loss_tracker.result(),
                "kl_divergence": self.kl_divergence_tracker.result()
            }

DCGAN

DCGAN uses convolutional and convolutional-transpose layers in the generator and discriminator, respectively. It was proposed by Radford et. al. in the paper Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks.

The Discriminator consists of strided convolution layers of 2x2 to downsample the input image, batch normalization layers, and LeakyRelu as activation function. We have replaced pooling layers with strided convolutions, as a strided convolution can decrease the dimension by jumping multiple pixels between convolutions instead of sliding the kernel one-by-one. The discriminator takes a 3x64x64 input image. The discriminator is trained to minimize the binary cross entropy loss function, which is suitable for binary classification.


The Generator consists of convolutional-transpose layers, batch normalization layers, and ReLU activations. The output will be a 3x64x64 RGB image.


Other key features of DCGAN's include the use of ReLU activation functions in the generator (except for the output layer which uses Tanh), and the elimination of fully connected layers and directly connect the output to the convoluational layers where possible.


DCGANs have been used in various applications like photo editing, art creation, image super-resolution, and more. They are particularly noted for their ability to generate high-quality images and learn hierarchical representations of objects in images.

Insights:

In [ ]:
class DCGAN(GAN_template):
    def __init__(self, latent_dim):
        super().__init__(latent_dim)

    def define_discriminator(self, in_shape=(32,32,3)):
        model = Sequential()
        model.add(Conv2D(64, (3,3), padding='same', input_shape=in_shape))
        model.add(LeakyReLU(alpha=0.2))
        # Downsample
        model.add(Conv2D(64, (3,3), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Downsample
        model.add(Conv2D(64, (3,3), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Downsample
        model.add(Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_regularizer=l1_l2(l1=0.0015)))
        model.add(LeakyReLU(alpha=0.2))
        # Classifier
        model.add(Flatten())
        model.add(Dropout(0.5))
        model.add(Dense(1, activation='sigmoid'))
        # Compile Model
        model.compile(loss='binary_crossentropy', optimizer = Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
        return model

    def define_generator(self, latent_dim):
        model = Sequential()
        # Nodes to represent a low-resolution version of the output image
        n_nodes = 256 * 4 * 4
        model.add(Dense(n_nodes, input_dim=latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Reshape((4, 4, 256))) # Activations from these nodes can then be reshaped into something image-like, e.g. 256 different 4 x 4 feature maps
        # Upsample
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')) # Combines UpSampling & Conv2D layers, stride of 2x2 quadruples area of the input feature maps
        model.add(LeakyReLU(alpha=0.2))
        # Upsample
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Upsample
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(3, (3,3), activation='tanh', padding='same')) # Three filters for three color channels
        return model

dcgan = DCGAN(latent_dim=100)
dcgan.compile(
    d_optimizer=Adam(learning_rate=0.0003),
    g_optimizer=Adam(learning_rate=0.0003),
    loss_fn=BinaryCrossentropy(from_logits=True),
)
dcgan_callback = CustomCallback(d_losses = dcgan.d_loss_list, g_losses = dcgan.g_loss_list, kl_div = dcgan.kl_div_list,model = dcgan, filepath = "output/models/dcgan/")
dcgan.fit(X_train, epochs = 200, callbacks = [dcgan_callback])
Epoch 1/200
1562/1562 [==============================] - 37s 23ms/step - d_loss: 0.4620 - g_loss: 2.9281 - kl_divergence: 0.9407
Epoch 2/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5346 - g_loss: 1.7597 - kl_divergence: 0.5963
Epoch 3/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5822 - g_loss: 1.4851 - kl_divergence: 0.4541
Epoch 4/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5950 - g_loss: 1.5685 - kl_divergence: 0.5027
Epoch 5/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5956 - g_loss: 1.4207 - kl_divergence: 0.4210
Epoch 6/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5994 - g_loss: 1.3362 - kl_divergence: 0.4004
Epoch 7/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6078 - g_loss: 1.5065 - kl_divergence: 0.4278
Epoch 8/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5852 - g_loss: 1.3864 - kl_divergence: 0.4132
Epoch 9/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6130 - g_loss: 1.2438 - kl_divergence: 0.3762
Epoch 10/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.6243 - g_loss: 1.2382 - kl_divergence: 0.3712
Epoch 11/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6193 - g_loss: 1.1890 - kl_divergence: 0.3827
Epoch 12/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6070 - g_loss: 1.3629 - kl_divergence: 0.3747
Epoch 13/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6219 - g_loss: 1.2310 - kl_divergence: 0.3725
Epoch 14/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.6071 - g_loss: 1.2327 - kl_divergence: 0.3730
Epoch 15/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6028 - g_loss: 1.3161 - kl_divergence: 0.3580
Epoch 16/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6075 - g_loss: 1.2134 - kl_divergence: 0.3661
Epoch 17/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6232 - g_loss: 1.1583 - kl_divergence: 0.3574
Epoch 18/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6021 - g_loss: 1.3342 - kl_divergence: 0.3630
Epoch 19/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6132 - g_loss: 1.2154 - kl_divergence: 0.3565
Epoch 20/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6152 - g_loss: 1.2352 - kl_divergence: 0.3532
Epoch 21/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6013 - g_loss: 1.3076 - kl_divergence: 0.3552
Epoch 22/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6252 - g_loss: 1.3256 - kl_divergence: 0.3809
Epoch 23/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5911 - g_loss: 1.2972 - kl_divergence: 0.3654
Epoch 24/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5913 - g_loss: 1.3727 - kl_divergence: 0.3902
Epoch 25/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5696 - g_loss: 1.4636 - kl_divergence: 0.3603
Epoch 26/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5680 - g_loss: 1.4073 - kl_divergence: 0.3575
Epoch 27/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5791 - g_loss: 1.3054 - kl_divergence: 0.3801
Epoch 28/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5646 - g_loss: 1.3885 - kl_divergence: 0.3635
Epoch 29/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5773 - g_loss: 1.3255 - kl_divergence: 0.3781
Epoch 30/200
1562/1562 [==============================] - 37s 23ms/step - d_loss: 0.5655 - g_loss: 1.3874 - kl_divergence: 0.3788
Epoch 31/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5847 - g_loss: 1.3226 - kl_divergence: 0.3655
Epoch 32/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5868 - g_loss: 1.3944 - kl_divergence: 0.3694
Epoch 33/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5600 - g_loss: 1.4091 - kl_divergence: 0.3805
Epoch 34/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5813 - g_loss: 1.2923 - kl_divergence: 0.3769
Epoch 35/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5915 - g_loss: 1.2569 - kl_divergence: 0.3660
Epoch 36/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5841 - g_loss: 1.2796 - kl_divergence: 0.3703
Epoch 37/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5931 - g_loss: 1.2333 - kl_divergence: 0.3721
Epoch 38/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5917 - g_loss: 1.2376 - kl_divergence: 0.3758
Epoch 39/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.6021 - g_loss: 1.2194 - kl_divergence: 0.3736
Epoch 40/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5958 - g_loss: 1.2724 - kl_divergence: 0.3665
Epoch 41/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5923 - g_loss: 1.2394 - kl_divergence: 0.3566
Epoch 42/200
1562/1562 [==============================] - 37s 23ms/step - d_loss: 0.6100 - g_loss: 1.1156 - kl_divergence: 0.3535
Epoch 43/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6049 - g_loss: 1.1759 - kl_divergence: 0.3605
Epoch 44/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6048 - g_loss: 1.2029 - kl_divergence: 0.3526
Epoch 45/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5992 - g_loss: 1.2041 - kl_divergence: 0.3490
Epoch 46/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.6130 - g_loss: 1.1736 - kl_divergence: 0.3476
Epoch 47/200
1562/1562 [==============================] - 32s 21ms/step - d_loss: 0.6043 - g_loss: 1.1408 - kl_divergence: 0.3466
Epoch 48/200
1562/1562 [==============================] - 32s 21ms/step - d_loss: 0.5928 - g_loss: 1.2183 - kl_divergence: 0.3482
Epoch 49/200
1562/1562 [==============================] - 32s 21ms/step - d_loss: 0.6014 - g_loss: 1.1852 - kl_divergence: 0.3518
Epoch 50/200
1562/1562 [==============================] - 32s 21ms/step - d_loss: 0.5988 - g_loss: 1.1598 - kl_divergence: 0.3476
Epoch 51/200
1562/1562 [==============================] - 32s 21ms/step - d_loss: 0.5857 - g_loss: 1.2534 - kl_divergence: 0.3535
Epoch 52/200
1562/1562 [==============================] - 32s 21ms/step - d_loss: 0.5901 - g_loss: 1.2218 - kl_divergence: 0.3512
Epoch 53/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5925 - g_loss: 1.2221 - kl_divergence: 0.3480
Epoch 54/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5746 - g_loss: 1.3319 - kl_divergence: 0.3488
Epoch 55/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5755 - g_loss: 1.2579 - kl_divergence: 0.3497
Epoch 56/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5878 - g_loss: 1.2517 - kl_divergence: 0.3483
Epoch 57/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5866 - g_loss: 1.1762 - kl_divergence: 0.3628
Epoch 58/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5816 - g_loss: 1.2031 - kl_divergence: 0.3559
Epoch 59/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5997 - g_loss: 1.1481 - kl_divergence: 0.3522
Epoch 60/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5970 - g_loss: 1.1498 - kl_divergence: 0.3567
Epoch 61/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6024 - g_loss: 1.1170 - kl_divergence: 0.3555
Epoch 62/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5952 - g_loss: 1.1467 - kl_divergence: 0.3547
Epoch 63/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.6071 - g_loss: 1.1054 - kl_divergence: 0.3600
Epoch 64/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6104 - g_loss: 1.1025 - kl_divergence: 0.3544
Epoch 65/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5994 - g_loss: 1.1217 - kl_divergence: 0.3597
Epoch 66/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6006 - g_loss: 1.1255 - kl_divergence: 0.3495
Epoch 67/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6018 - g_loss: 1.1059 - kl_divergence: 0.3489
Epoch 68/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6009 - g_loss: 1.1285 - kl_divergence: 0.3471
Epoch 69/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.6035 - g_loss: 1.1244 - kl_divergence: 0.3538
Epoch 70/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5891 - g_loss: 1.1846 - kl_divergence: 0.3511
Epoch 71/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5823 - g_loss: 1.1779 - kl_divergence: 0.3523
Epoch 72/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5877 - g_loss: 1.1787 - kl_divergence: 0.3544
Epoch 73/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5797 - g_loss: 1.2198 - kl_divergence: 0.3522
Epoch 74/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5797 - g_loss: 1.1848 - kl_divergence: 0.3537
Epoch 75/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5837 - g_loss: 1.1969 - kl_divergence: 0.3542
Epoch 76/200
1562/1562 [==============================] - 37s 24ms/step - d_loss: 0.5762 - g_loss: 1.2353 - kl_divergence: 0.3533
Epoch 77/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5787 - g_loss: 1.2061 - kl_divergence: 0.3569
Epoch 78/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5753 - g_loss: 1.2357 - kl_divergence: 0.3573
Epoch 79/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5752 - g_loss: 1.2425 - kl_divergence: 0.3561
Epoch 80/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5699 - g_loss: 1.2389 - kl_divergence: 0.3590
Epoch 81/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5881 - g_loss: 1.2102 - kl_divergence: 0.3481
Epoch 82/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5780 - g_loss: 1.2019 - kl_divergence: 0.3618
Epoch 83/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5682 - g_loss: 1.2382 - kl_divergence: 0.3646
Epoch 84/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5718 - g_loss: 1.2492 - kl_divergence: 0.3620
Epoch 85/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5714 - g_loss: 1.2350 - kl_divergence: 0.3624
Epoch 86/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5677 - g_loss: 1.2485 - kl_divergence: 0.3618
Epoch 87/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5691 - g_loss: 1.2545 - kl_divergence: 0.3598
Epoch 88/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5729 - g_loss: 1.2548 - kl_divergence: 0.3557
Epoch 89/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5694 - g_loss: 1.2296 - kl_divergence: 0.3610
Epoch 90/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5667 - g_loss: 1.2560 - kl_divergence: 0.3606
Epoch 91/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5607 - g_loss: 1.2710 - kl_divergence: 0.3625
Epoch 92/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5805 - g_loss: 1.2325 - kl_divergence: 0.3580
Epoch 93/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5596 - g_loss: 1.2847 - kl_divergence: 0.3604
Epoch 94/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5592 - g_loss: 1.2449 - kl_divergence: 0.3561
Epoch 95/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5701 - g_loss: 1.2569 - kl_divergence: 0.3565
Epoch 96/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5673 - g_loss: 1.2507 - kl_divergence: 0.3584
Epoch 97/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5537 - g_loss: 1.2816 - kl_divergence: 0.3558
Epoch 98/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5629 - g_loss: 1.2954 - kl_divergence: 0.3567
Epoch 99/200
1562/1562 [==============================] - 37s 24ms/step - d_loss: 0.5559 - g_loss: 1.3044 - kl_divergence: 0.3575
Epoch 100/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5604 - g_loss: 1.2645 - kl_divergence: 0.3512
Epoch 101/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5596 - g_loss: 1.2886 - kl_divergence: 0.3518
Epoch 102/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5627 - g_loss: 1.2622 - kl_divergence: 0.3646
Epoch 103/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5586 - g_loss: 1.2667 - kl_divergence: 0.3632
Epoch 104/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5575 - g_loss: 1.2881 - kl_divergence: 0.3569
Epoch 105/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5596 - g_loss: 1.2592 - kl_divergence: 0.3560
Epoch 106/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5533 - g_loss: 1.3084 - kl_divergence: 0.3587
Epoch 107/200
1562/1562 [==============================] - 35s 22ms/step - d_loss: 0.5541 - g_loss: 1.2693 - kl_divergence: 0.3585
Epoch 108/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5517 - g_loss: 1.3361 - kl_divergence: 0.3595
Epoch 109/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5534 - g_loss: 1.2923 - kl_divergence: 0.3525
Epoch 110/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5522 - g_loss: 1.3019 - kl_divergence: 0.3573
Epoch 111/200
1562/1562 [==============================] - 35s 22ms/step - d_loss: 0.5476 - g_loss: 1.3247 - kl_divergence: 0.3599
Epoch 112/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5421 - g_loss: 1.3081 - kl_divergence: 0.3577
Epoch 113/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5627 - g_loss: 1.2644 - kl_divergence: 0.3509
Epoch 114/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5542 - g_loss: 1.2908 - kl_divergence: 0.3541
Epoch 115/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5572 - g_loss: 1.2493 - kl_divergence: 0.3529
Epoch 116/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5610 - g_loss: 1.3004 - kl_divergence: 0.3507
Epoch 117/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5583 - g_loss: 1.2566 - kl_divergence: 0.3528
Epoch 118/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5556 - g_loss: 1.2753 - kl_divergence: 0.3613
Epoch 119/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5610 - g_loss: 1.2382 - kl_divergence: 0.3560
Epoch 120/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5558 - g_loss: 1.2500 - kl_divergence: 0.3642
Epoch 121/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5601 - g_loss: 1.2719 - kl_divergence: 0.3568
Epoch 122/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5512 - g_loss: 1.3111 - kl_divergence: 0.3615
Epoch 123/200
1562/1562 [==============================] - 35s 22ms/step - d_loss: 0.5610 - g_loss: 1.2511 - kl_divergence: 0.3686
Epoch 124/200
1562/1562 [==============================] - 34s 22ms/step - d_loss: 0.5603 - g_loss: 1.2874 - kl_divergence: 0.3576
Epoch 125/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5575 - g_loss: 1.2420 - kl_divergence: 0.3533
Epoch 126/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5498 - g_loss: 1.2881 - kl_divergence: 0.3556
Epoch 127/200
1562/1562 [==============================] - 38s 24ms/step - d_loss: 0.5613 - g_loss: 1.2682 - kl_divergence: 0.3564
Epoch 128/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5518 - g_loss: 1.2522 - kl_divergence: 0.3571
Epoch 129/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5522 - g_loss: 1.2916 - kl_divergence: 0.3568
Epoch 130/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5506 - g_loss: 1.2723 - kl_divergence: 0.3549
Epoch 131/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5537 - g_loss: 1.2758 - kl_divergence: 0.3550
Epoch 132/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5510 - g_loss: 1.2763 - kl_divergence: 0.3569
Epoch 133/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5469 - g_loss: 1.3028 - kl_divergence: 0.3553
Epoch 134/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5610 - g_loss: 1.2862 - kl_divergence: 0.3518
Epoch 135/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5464 - g_loss: 1.2623 - kl_divergence: 0.3509
Epoch 136/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5381 - g_loss: 1.3185 - kl_divergence: 0.3563
Epoch 137/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5447 - g_loss: 1.3430 - kl_divergence: 0.3515
Epoch 138/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5464 - g_loss: 1.2760 - kl_divergence: 0.3592
Epoch 139/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5436 - g_loss: 1.3322 - kl_divergence: 0.3521
Epoch 140/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5414 - g_loss: 1.2778 - kl_divergence: 0.3591
Epoch 141/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5434 - g_loss: 1.3041 - kl_divergence: 0.3534
Epoch 142/200
1562/1562 [==============================] - 35s 22ms/step - d_loss: 0.5407 - g_loss: 1.3358 - kl_divergence: 0.3551
Epoch 143/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5375 - g_loss: 1.3016 - kl_divergence: 0.3617
Epoch 144/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5436 - g_loss: 1.3120 - kl_divergence: 0.3521
Epoch 145/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5414 - g_loss: 1.2944 - kl_divergence: 0.3550
Epoch 146/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5414 - g_loss: 1.3288 - kl_divergence: 0.3521
Epoch 147/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5385 - g_loss: 1.3578 - kl_divergence: 0.3557
Epoch 148/200
1562/1562 [==============================] - 36s 23ms/step - d_loss: 0.5399 - g_loss: 1.2914 - kl_divergence: 0.3550
Epoch 149/200
1562/1562 [==============================] - 34s 22ms/step - d_loss: 0.5329 - g_loss: 1.3307 - kl_divergence: 0.3606
Epoch 150/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5351 - g_loss: 1.3280 - kl_divergence: 0.3570
Epoch 151/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5405 - g_loss: 1.3040 - kl_divergence: 0.3544
Epoch 152/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5433 - g_loss: 1.2975 - kl_divergence: 0.3530
Epoch 153/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5412 - g_loss: 1.3091 - kl_divergence: 0.3564
Epoch 154/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5423 - g_loss: 1.3010 - kl_divergence: 0.3527
Epoch 155/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5413 - g_loss: 1.3058 - kl_divergence: 0.3554
Epoch 156/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5329 - g_loss: 1.3560 - kl_divergence: 0.3547
Epoch 157/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5359 - g_loss: 1.3165 - kl_divergence: 0.3545
Epoch 158/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5389 - g_loss: 1.3312 - kl_divergence: 0.3525
Epoch 159/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5350 - g_loss: 1.3174 - kl_divergence: 0.3555
Epoch 160/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5302 - g_loss: 1.3220 - kl_divergence: 0.3563
Epoch 161/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5321 - g_loss: 1.3518 - kl_divergence: 0.3523
Epoch 162/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5321 - g_loss: 1.3550 - kl_divergence: 0.3504
Epoch 163/200
1562/1562 [==============================] - 35s 23ms/step - d_loss: 0.5275 - g_loss: 1.3714 - kl_divergence: 0.3541
Epoch 164/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5200 - g_loss: 1.3458 - kl_divergence: 0.3566
Epoch 165/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5189 - g_loss: 1.3846 - kl_divergence: 0.3608
Epoch 166/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5202 - g_loss: 1.3991 - kl_divergence: 0.3591
Epoch 167/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5233 - g_loss: 1.3678 - kl_divergence: 0.3558
Epoch 168/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5234 - g_loss: 1.3949 - kl_divergence: 0.3547
Epoch 169/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5186 - g_loss: 1.3869 - kl_divergence: 0.3568
Epoch 170/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5158 - g_loss: 1.3899 - kl_divergence: 0.3619
Epoch 171/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5196 - g_loss: 1.4413 - kl_divergence: 0.3572
Epoch 172/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5201 - g_loss: 1.3844 - kl_divergence: 0.3572
Epoch 173/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5183 - g_loss: 1.3875 - kl_divergence: 0.3585
Epoch 174/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5119 - g_loss: 1.4424 - kl_divergence: 0.3580
Epoch 175/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5131 - g_loss: 1.4047 - kl_divergence: 0.3569
Epoch 176/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5080 - g_loss: 1.4314 - kl_divergence: 0.3531
Epoch 177/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5145 - g_loss: 1.4126 - kl_divergence: 0.3548
Epoch 178/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5111 - g_loss: 1.4162 - kl_divergence: 0.3591
Epoch 179/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5136 - g_loss: 1.4060 - kl_divergence: 0.3570
Epoch 180/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5082 - g_loss: 1.4531 - kl_divergence: 0.3605
Epoch 181/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5110 - g_loss: 1.4555 - kl_divergence: 0.3600
Epoch 182/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5081 - g_loss: 1.3883 - kl_divergence: 0.3598
Epoch 183/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5035 - g_loss: 1.4574 - kl_divergence: 0.3585
Epoch 184/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5173 - g_loss: 1.4260 - kl_divergence: 0.3591
Epoch 185/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5080 - g_loss: 1.4357 - kl_divergence: 0.3574
Epoch 186/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5031 - g_loss: 1.4888 - kl_divergence: 0.3591
Epoch 187/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5039 - g_loss: 1.4264 - kl_divergence: 0.3632
Epoch 188/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5063 - g_loss: 1.4129 - kl_divergence: 0.3643
Epoch 189/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5042 - g_loss: 1.4980 - kl_divergence: 0.3597
Epoch 190/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.4995 - g_loss: 1.4302 - kl_divergence: 0.3591
Epoch 191/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5021 - g_loss: 1.4641 - kl_divergence: 0.3610
Epoch 192/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5033 - g_loss: 1.5035 - kl_divergence: 0.3614
Epoch 193/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.4998 - g_loss: 1.4920 - kl_divergence: 0.3562
Epoch 194/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5021 - g_loss: 1.4419 - kl_divergence: 0.3536
Epoch 195/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5064 - g_loss: 1.5291 - kl_divergence: 0.3603
Epoch 196/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5051 - g_loss: 1.4093 - kl_divergence: 0.3593
Epoch 197/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.4958 - g_loss: 1.4799 - kl_divergence: 0.3589
Epoch 198/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5000 - g_loss: 1.4938 - kl_divergence: 0.3601
Epoch 199/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5001 - g_loss: 1.4814 - kl_divergence: 0.3581
Epoch 200/200
1562/1562 [==============================] - 33s 21ms/step - d_loss: 0.5042 - g_loss: 1.5149 - kl_divergence: 0.3614
Out[ ]:
<keras.callbacks.History at 0x229492c46d0>
In [ ]:
class DCGAN(GAN_template):
    def __init__(self, latent_dim):
        super().__init__(latent_dim)

    def define_discriminator(self, in_shape=(32,32,3)):
        model = Sequential()
        model.add(Conv2D(64, (3,3), padding='same', input_shape=in_shape))
        model.add(LeakyReLU(alpha=0.2))
        # Downsample
        model.add(Conv2D(64, (3,3), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Downsample
        model.add(Conv2D(64, (3,3), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Downsample
        model.add(Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_regularizer=l1_l2(l1=0.0015)))
        model.add(LeakyReLU(alpha=0.2))
        # Classifier
        model.add(Flatten())
        model.add(Dropout(0.5))
        model.add(Dense(1, activation='sigmoid'))
        # Compile Model
        model.compile(loss='binary_crossentropy', optimizer = Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])
        return model

    def define_generator(self, latent_dim):
        model = Sequential()
        # Nodes to represent a low-resolution version of the output image
        n_nodes = 256 * 4 * 4
        model.add(Dense(n_nodes, input_dim=latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Reshape((4, 4, 256))) # Activations from these nodes can then be reshaped into something image-like, e.g. 256 different 4 x 4 feature maps
        # Upsample
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same')) # Combines UpSampling & Conv2D layers, stride of 2x2 quadruples area of the input feature maps
        model.add(LeakyReLU(alpha=0.2))
        # Upsample
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        # Upsample
        model.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Conv2D(3, (3,3), activation='tanh', padding='same')) # Three filters for three color channels
        return model
In [ ]:
input_folder = './output/models/dcgan/generated/'  # Replace with your frames directory
output_file = 'output_video.mp4'      # Replace with your desired output file path
gnnf.create_video_from_frames(input_folder, output_file)

From the above, we can see that the images generated by DCGAN are somewhat coherent, with makeshift objects being present in the imgaes. The images also have detailing to them, and some what resemble the images in the CIFAR-10 dataset.

However, we don't know which class of images the model is trying to predict. Hence, we shall try to solve with problem with the next model architecture.


Conditional GAN (cGAN)

Conditional Generative Adversarial Networks (cGANs) represent an advanced evolution in the realm of Generative Adversarial Networks (GANs), specifically designed for generating data samples under defined conditions. The foundational work on cGANs is attributed to Mirza and Osindero in their seminal paper "Conditional Generative Adversarial Nets".

In the architecture of cGANs, the generator and discriminator are both conditioned on additional information, such as labels or tags, which guide the data generation process. This conditional approach allows for the generation of targeted data samples, enhancing the versatility and effectiveness of the network.

Some architectural features of cGAN include the use of conditional information such as labels to help steer the data generation process to align with conditions such as generating images of a class. This information is also given to the discriminator to help access if the generated data aligns with the given conditions. In terms of activation functions, cGANs often employ similar activation techniques to DCGANs.

In [ ]:
num_classes = 10
class cGAN(GAN_template):
    def __init__(self, latent_dim):
        super().__init__(latent_dim)
        self.num_classes = num_classes

    def define_discriminator(self, in_shape=(32,32,3)):
        # Image input
        image_input = Input(shape=in_shape)
        label_input = Input(shape=(1,))
        label_embedding = Embedding(num_classes, np.prod(in_shape))(label_input)
        label_embedding = Dense(np.prod(in_shape))(label_embedding)
        label_embedding = Reshape(in_shape)(label_embedding)
        concatenated = Concatenate()([image_input, label_embedding])
        x = Conv2D(64, (3,3), padding='same')(concatenated)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(64, (3,3), strides=(2,2), padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(128, (3,3), strides=(2,2), padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_regularizer=l1_l2(l1=0.001, l2=0.001))(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Flatten()(x)
        x = Dropout(0.4)(x)
        output = Dense(1, activation='sigmoid')(x)

        model = Model(inputs=[image_input, label_input], outputs=output)
        model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5), metrics=['accuracy'])
        return model

    def define_generator(self, latent_dim):
        model = Sequential()
            # Model for processing the labels
        label_input = tf.keras.Input(shape=(1,), dtype='int32')
        label_embedding = Embedding(num_classes, latent_dim)(label_input)
        label_embedding = Flatten()(label_embedding)
        latent_input = tf.keras.Input(shape=(latent_dim,))
        merged_input = Concatenate()([latent_input, label_embedding])

        # Sequential model for the generator
        generator = Sequential([
            Dense(8 * 8 * 256, input_shape=(latent_dim * 2,)),
            LeakyReLU(alpha=0.2),
            Reshape((8, 8, 256)),
            Conv2DTranspose(256, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(3, (3, 3), activation='tanh', padding='same')
        ])

        # Pass the merged input through the generator model
        generated_image = generator(merged_input)

        # Final cGAN generator model
        model = Model(inputs=[latent_input, label_input], outputs=generated_image)
        return model

    def train_step(self, data):
        if isinstance(data, tuple):
               real_images, real_labels = data
        else:
            real_images = data
            real_labels = tf.random.uniform([tf.shape(real_images)[0]], minval=0, maxval=self.num_classes, dtype=tf.int32)

        batch_size = real_images[0].shape[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        real_images = tf.reshape(real_images, [batch_size, 32, 32, 3])
        fake_labels = tf.random.uniform([batch_size], minval=0, maxval=self.num_classes, dtype=tf.int32)
        generated_images = self.generator([random_latent_vectors, fake_labels])
        combined_images = tf.concat([generated_images, tf.cast(real_images, tf.float32)], axis=0)
        real_labels = tf.squeeze(real_labels)
        combined_labels = tf.concat([tf.cast(fake_labels, 'uint8'), real_labels], axis=0)  # Concatenate labels as well
        discriminator_labels = tf.concat(
            [tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], axis=0
        )
        with tf.GradientTape() as tape:
            predictions = self.discriminator([combined_images, combined_labels])
            d_loss = self.loss_fn(discriminator_labels, predictions)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))
        misleading_labels = tf.ones((batch_size, 1))
        with tf.GradientTape() as tape:
            generated_images = self.generator([random_latent_vectors, fake_labels])
            predictions = self.discriminator([generated_images, fake_labels])
            g_loss = self.loss_fn(misleading_labels, predictions)
            kl_loss = self.kl_divergence(real_images, generated_images)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        self.kl_divergence_tracker.update_state(kl_loss)

        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
            "kl_divergence": self.kl_divergence_tracker.result()
        }

    @staticmethod
    def save_plot(examples, epoch, d_losses, g_losses, kl_div, filepath):
        fig = plt.figure(figsize=(20, 15))
        gs = fig.add_gridspec(10, 10, height_ratios=[1]*10, width_ratios=[1]*10, hspace=0.25, wspace=0.2)
        examples = (examples + 1) / 2.0
        class_names = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

        for i in range(10*5):  # 3 images per class, 10 classes
            class_index = i // 5  # Determine class based on order
            ax = fig.add_subplot(gs[i % 5, class_index])
            ax.axis('off')
            ax.imshow(examples[i])
            if i % 5 == 0:
                ax.set_title(class_names[class_index], fontsize=8)

        # Plot for discriminator losses
        ax_loss = fig.add_subplot(gs[5:8, 0:3])
        ax_loss.plot(d_losses, label="Discriminator Loss")
        ax_loss.set_title("Discriminator Loss")

        # Plot for generator losses
        ax_g_loss = fig.add_subplot(gs[5:8, 3:7])
        ax_g_loss.plot(g_losses, label="Generator Loss")
        ax_g_loss.set_title("Generator Loss")

        ax_kl_div = fig.add_subplot(gs[5:8, 7:10])
        ax_kl_div.plot(kl_div, label="KL Divergence")
        ax_kl_div.set_title("KL Divergence")

        plt.suptitle(f"Epoch {epoch+1}", fontsize=18, y=0.95)
        plt.tight_layout()
        plt.savefig(f"{filepath}generated/generated_plot_e{epoch+1}.png", bbox_inches='tight')
        plt.close()

    @staticmethod
    def generate_fake_samples(self, generator, n_samples=5, latent_dim=100):
        X, y = [], []
        for class_label in range(10):  # CIFAR-10 has 10 classes
            # Generate latent points
            x_input = np.random.randn(latent_dim * n_samples)
            x_input = x_input.reshape(n_samples, latent_dim)
            # Create class labels
            labels = np.full((n_samples, 1), class_label)
            # Generate images
            images = generator.predict([x_input, labels], verbose=0)
            X.extend(images)
            y.extend(labels)
        return np.asarray(X), np.asarray(y)

cgan = cGAN(latent_dim=100)
cgan.compile(
    d_optimizer=Adam(learning_rate=0.0003),
    g_optimizer=Adam(learning_rate=0.0003),
    loss_fn=BinaryCrossentropy(from_logits=True),
)
cgan_callback = CustomCallback(d_losses = cgan.d_loss_list, g_losses = cgan.g_loss_list, kl_div=cgan.kl_div_list, model = cgan, filepath = "output/models/cgan/")
cgan.fit(X_train, epochs = 200, callbacks = [cgan_callback])
Epoch 1/200
1562/1562 [==============================] - 58s 35ms/step - d_loss: 0.4840 - g_loss: 2.5247 - kl_divergence: 0.6751
Epoch 2/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.5458 - g_loss: 1.5353 - kl_divergence: 0.4259
Epoch 3/200
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.5698 - g_loss: 1.4624 - kl_divergence: 0.4204
Epoch 4/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.5858 - g_loss: 1.4409 - kl_divergence: 0.3989
Epoch 5/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.6061 - g_loss: 1.2307 - kl_divergence: 0.3679
Epoch 6/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.6118 - g_loss: 1.5189 - kl_divergence: 0.4050
Epoch 7/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.5794 - g_loss: 1.3202 - kl_divergence: 0.3703
Epoch 8/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.5902 - g_loss: 1.3091 - kl_divergence: 0.3445
Epoch 9/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.5745 - g_loss: 1.5144 - kl_divergence: 0.3493
Epoch 10/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.5347 - g_loss: 1.4671 - kl_divergence: 0.3483
Epoch 11/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.5201 - g_loss: 1.5850 - kl_divergence: 0.3519
Epoch 12/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.5057 - g_loss: 1.6061 - kl_divergence: 0.3539
Epoch 13/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4607 - g_loss: 1.8370 - kl_divergence: 0.3662
Epoch 14/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4397 - g_loss: 1.9564 - kl_divergence: 0.3696
Epoch 15/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4052 - g_loss: 2.0620 - kl_divergence: 0.3698
Epoch 16/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3994 - g_loss: 2.1406 - kl_divergence: 0.3779
Epoch 17/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.4285 - g_loss: 1.8424 - kl_divergence: 0.3678
Epoch 18/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4436 - g_loss: 1.7842 - kl_divergence: 0.3660
Epoch 19/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4729 - g_loss: 1.5721 - kl_divergence: 0.3576
Epoch 20/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4682 - g_loss: 1.5564 - kl_divergence: 0.3533
Epoch 21/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4808 - g_loss: 1.5065 - kl_divergence: 0.3470
Epoch 22/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4863 - g_loss: 1.4848 - kl_divergence: 0.3475
Epoch 23/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.5005 - g_loss: 1.4360 - kl_divergence: 0.3517
Epoch 24/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4891 - g_loss: 1.4734 - kl_divergence: 0.3485
Epoch 25/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4824 - g_loss: 1.5402 - kl_divergence: 0.3463
Epoch 26/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4841 - g_loss: 1.5289 - kl_divergence: 0.3442
Epoch 27/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4574 - g_loss: 1.5744 - kl_divergence: 0.3450
Epoch 28/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4942 - g_loss: 1.5094 - kl_divergence: 0.3430
Epoch 29/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4561 - g_loss: 1.5703 - kl_divergence: 0.3419
Epoch 30/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.4706 - g_loss: 1.5855 - kl_divergence: 0.3416
Epoch 31/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4739 - g_loss: 1.5134 - kl_divergence: 0.3412
Epoch 32/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4813 - g_loss: 1.5302 - kl_divergence: 0.3423
Epoch 33/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4710 - g_loss: 1.4704 - kl_divergence: 0.3450
Epoch 34/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4941 - g_loss: 1.4983 - kl_divergence: 0.3448
Epoch 35/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4807 - g_loss: 1.4557 - kl_divergence: 0.3456
Epoch 36/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4781 - g_loss: 1.4953 - kl_divergence: 0.3457
Epoch 37/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4887 - g_loss: 1.4842 - kl_divergence: 0.3469
Epoch 38/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4851 - g_loss: 1.5009 - kl_divergence: 0.3402
Epoch 39/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4746 - g_loss: 1.4589 - kl_divergence: 0.3485
Epoch 40/200
1562/1562 [==============================] - 53s 34ms/step - d_loss: 0.4696 - g_loss: 1.5598 - kl_divergence: 0.3434
Epoch 41/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4795 - g_loss: 1.4673 - kl_divergence: 0.3431
Epoch 42/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4643 - g_loss: 1.5405 - kl_divergence: 0.3409
Epoch 43/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4746 - g_loss: 1.5465 - kl_divergence: 0.3392
Epoch 44/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4611 - g_loss: 1.5420 - kl_divergence: 0.3384
Epoch 45/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4592 - g_loss: 1.5817 - kl_divergence: 0.3415
Epoch 46/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4678 - g_loss: 1.6379 - kl_divergence: 0.3394
Epoch 47/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4598 - g_loss: 1.5374 - kl_divergence: 0.3413
Epoch 48/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.4592 - g_loss: 1.6224 - kl_divergence: 0.3402
Epoch 49/200
1562/1562 [==============================] - 54s 34ms/step - d_loss: 0.4636 - g_loss: 1.5936 - kl_divergence: 0.3361
Epoch 50/200
1562/1562 [==============================] - 53s 34ms/step - d_loss: 0.4593 - g_loss: 1.5885 - kl_divergence: 0.3355
Epoch 51/200
1562/1562 [==============================] - 57s 37ms/step - d_loss: 0.4506 - g_loss: 1.6277 - kl_divergence: 0.3372
Epoch 52/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4474 - g_loss: 1.6525 - kl_divergence: 0.3379
Epoch 53/200
1562/1562 [==============================] - 54s 34ms/step - d_loss: 0.4449 - g_loss: 1.6906 - kl_divergence: 0.3371
Epoch 54/200
1562/1562 [==============================] - 54s 35ms/step - d_loss: 0.4386 - g_loss: 1.6731 - kl_divergence: 0.3436
Epoch 55/200
1562/1562 [==============================] - 54s 35ms/step - d_loss: 0.4413 - g_loss: 1.7110 - kl_divergence: 0.3354
Epoch 56/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4303 - g_loss: 1.7006 - kl_divergence: 0.3370
Epoch 57/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.4274 - g_loss: 1.7386 - kl_divergence: 0.3360
Epoch 58/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4260 - g_loss: 1.7518 - kl_divergence: 0.3392
Epoch 59/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4263 - g_loss: 1.7579 - kl_divergence: 0.3404
Epoch 60/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4264 - g_loss: 1.7563 - kl_divergence: 0.3375
Epoch 61/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.4216 - g_loss: 1.8059 - kl_divergence: 0.3370
Epoch 62/200
1562/1562 [==============================] - 55s 36ms/step - d_loss: 0.4181 - g_loss: 1.7909 - kl_divergence: 0.3396
Epoch 63/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.4194 - g_loss: 1.8367 - kl_divergence: 0.3362
Epoch 64/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4147 - g_loss: 1.8420 - kl_divergence: 0.3386
Epoch 65/200
1562/1562 [==============================] - 58s 37ms/step - d_loss: 0.4142 - g_loss: 1.8307 - kl_divergence: 0.3376
Epoch 66/200
1562/1562 [==============================] - 56s 35ms/step - d_loss: 0.4173 - g_loss: 1.8338 - kl_divergence: 0.3382
Epoch 67/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.4117 - g_loss: 1.8508 - kl_divergence: 0.3372
Epoch 68/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.4075 - g_loss: 1.8894 - kl_divergence: 0.3358
Epoch 69/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.4035 - g_loss: 1.8991 - kl_divergence: 0.3349
Epoch 70/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.4053 - g_loss: 1.8909 - kl_divergence: 0.3370
Epoch 71/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3993 - g_loss: 1.9298 - kl_divergence: 0.3385
Epoch 72/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.3977 - g_loss: 1.9482 - kl_divergence: 0.3382
Epoch 73/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3979 - g_loss: 1.9846 - kl_divergence: 0.3388
Epoch 74/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3983 - g_loss: 1.9537 - kl_divergence: 0.3379
Epoch 75/200
1562/1562 [==============================] - 54s 35ms/step - d_loss: 0.3905 - g_loss: 1.9815 - kl_divergence: 0.3371
Epoch 76/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3907 - g_loss: 2.0003 - kl_divergence: 0.3386
Epoch 77/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.3883 - g_loss: 1.9946 - kl_divergence: 0.3390
Epoch 78/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3849 - g_loss: 2.0377 - kl_divergence: 0.3378
Epoch 79/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3852 - g_loss: 2.0344 - kl_divergence: 0.3414
Epoch 80/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.3842 - g_loss: 2.0532 - kl_divergence: 0.3396
Epoch 81/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.3803 - g_loss: 2.0818 - kl_divergence: 0.3402
Epoch 82/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3810 - g_loss: 2.0800 - kl_divergence: 0.3394
Epoch 83/200
1562/1562 [==============================] - 59s 38ms/step - d_loss: 0.3790 - g_loss: 2.0881 - kl_divergence: 0.3389
Epoch 84/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3766 - g_loss: 2.1098 - kl_divergence: 0.3380
Epoch 85/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3734 - g_loss: 2.1341 - kl_divergence: 0.3416
Epoch 86/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3699 - g_loss: 2.1488 - kl_divergence: 0.3401
Epoch 87/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3713 - g_loss: 2.1559 - kl_divergence: 0.3403
Epoch 88/200
1562/1562 [==============================] - 54s 34ms/step - d_loss: 0.3664 - g_loss: 2.1840 - kl_divergence: 0.3409
Epoch 89/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3697 - g_loss: 2.1760 - kl_divergence: 0.3404
Epoch 90/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3644 - g_loss: 2.1908 - kl_divergence: 0.3416
Epoch 91/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.3671 - g_loss: 2.2437 - kl_divergence: 0.3387
Epoch 92/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3607 - g_loss: 2.2272 - kl_divergence: 0.3440
Epoch 93/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3568 - g_loss: 2.2641 - kl_divergence: 0.3428
Epoch 94/200
1562/1562 [==============================] - 53s 34ms/step - d_loss: 0.3582 - g_loss: 2.2703 - kl_divergence: 0.3453
Epoch 95/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3574 - g_loss: 2.2715 - kl_divergence: 0.3436
Epoch 96/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3563 - g_loss: 2.2839 - kl_divergence: 0.3463
Epoch 97/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3569 - g_loss: 2.3054 - kl_divergence: 0.3418
Epoch 98/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3517 - g_loss: 2.3075 - kl_divergence: 0.3440
Epoch 99/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3490 - g_loss: 2.3251 - kl_divergence: 0.3416
Epoch 100/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3461 - g_loss: 2.3574 - kl_divergence: 0.3486
Epoch 101/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3450 - g_loss: 2.3640 - kl_divergence: 0.3458
Epoch 102/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3435 - g_loss: 2.3998 - kl_divergence: 0.3447
Epoch 103/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3401 - g_loss: 2.4136 - kl_divergence: 0.3430
Epoch 104/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3394 - g_loss: 2.4441 - kl_divergence: 0.3462
Epoch 105/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 0.3380 - g_loss: 2.4317 - kl_divergence: 0.3426
Epoch 106/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3369 - g_loss: 2.4472 - kl_divergence: 0.3467
Epoch 107/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3329 - g_loss: 2.4941 - kl_divergence: 0.3446
Epoch 108/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3307 - g_loss: 2.4686 - kl_divergence: 0.3431
Epoch 109/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3322 - g_loss: 2.5111 - kl_divergence: 0.3408
Epoch 110/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3535 - g_loss: 2.5937 - kl_divergence: 0.3419
Epoch 111/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3345 - g_loss: 2.4697 - kl_divergence: 0.3422
Epoch 112/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3301 - g_loss: 2.4993 - kl_divergence: 0.3429
Epoch 113/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3268 - g_loss: 2.5677 - kl_divergence: 0.3439
Epoch 114/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3242 - g_loss: 2.5815 - kl_divergence: 0.3412
Epoch 115/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3227 - g_loss: 2.5716 - kl_divergence: 0.3407
Epoch 116/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3230 - g_loss: 2.6080 - kl_divergence: 0.3451
Epoch 117/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3246 - g_loss: 2.6098 - kl_divergence: 0.3413
Epoch 118/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3198 - g_loss: 2.6258 - kl_divergence: 0.3451
Epoch 119/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3157 - g_loss: 2.6524 - kl_divergence: 0.3459
Epoch 120/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3141 - g_loss: 2.6667 - kl_divergence: 0.3419
Epoch 121/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3142 - g_loss: 2.6944 - kl_divergence: 0.3455
Epoch 122/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3134 - g_loss: 2.7040 - kl_divergence: 0.3468
Epoch 123/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3127 - g_loss: 2.7053 - kl_divergence: 0.3430
Epoch 124/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3095 - g_loss: 2.7472 - kl_divergence: 0.3472
Epoch 125/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3061 - g_loss: 2.7704 - kl_divergence: 0.3410
Epoch 126/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3128 - g_loss: 2.8328 - kl_divergence: 0.3419
Epoch 127/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3044 - g_loss: 2.7802 - kl_divergence: 0.3472
Epoch 128/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3051 - g_loss: 2.7825 - kl_divergence: 0.3439
Epoch 129/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3062 - g_loss: 2.7921 - kl_divergence: 0.3428
Epoch 130/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3003 - g_loss: 2.8314 - kl_divergence: 0.3449
Epoch 131/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3008 - g_loss: 2.8626 - kl_divergence: 0.3458
Epoch 132/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.3001 - g_loss: 2.8366 - kl_divergence: 0.3452
Epoch 133/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.2980 - g_loss: 2.8546 - kl_divergence: 0.3435
Epoch 134/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2993 - g_loss: 2.8737 - kl_divergence: 0.3431
Epoch 135/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2981 - g_loss: 2.8804 - kl_divergence: 0.3456
Epoch 136/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.3016 - g_loss: 2.8909 - kl_divergence: 0.3435
Epoch 137/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2963 - g_loss: 2.8966 - kl_divergence: 0.3467
Epoch 138/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2950 - g_loss: 2.9014 - kl_divergence: 0.3451
Epoch 139/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2954 - g_loss: 2.9124 - kl_divergence: 0.3446
Epoch 140/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2930 - g_loss: 2.9481 - kl_divergence: 0.3474
Epoch 141/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2938 - g_loss: 2.9328 - kl_divergence: 0.3456
Epoch 142/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2876 - g_loss: 2.9911 - kl_divergence: 0.3445
Epoch 143/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2889 - g_loss: 3.0045 - kl_divergence: 0.3457
Epoch 144/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2884 - g_loss: 2.9893 - kl_divergence: 0.3494
Epoch 145/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2863 - g_loss: 3.0292 - kl_divergence: 0.3454
Epoch 146/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2811 - g_loss: 3.0429 - kl_divergence: 0.3468
Epoch 147/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2841 - g_loss: 3.0720 - kl_divergence: 0.3456
Epoch 148/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2819 - g_loss: 3.0755 - kl_divergence: 0.3500
Epoch 149/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2839 - g_loss: 3.0660 - kl_divergence: 0.3517
Epoch 150/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2810 - g_loss: 3.0974 - kl_divergence: 0.3483
Epoch 151/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2808 - g_loss: 3.1015 - kl_divergence: 0.3474
Epoch 152/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2794 - g_loss: 3.1128 - kl_divergence: 0.3438
Epoch 153/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2734 - g_loss: 3.1639 - kl_divergence: 0.3466
Epoch 154/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2796 - g_loss: 3.1662 - kl_divergence: 0.3450
Epoch 155/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2796 - g_loss: 3.1588 - kl_divergence: 0.3458
Epoch 156/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2716 - g_loss: 3.1673 - kl_divergence: 0.3470
Epoch 157/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2726 - g_loss: 3.2010 - kl_divergence: 0.3474
Epoch 158/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2777 - g_loss: 3.2242 - kl_divergence: 0.3471
Epoch 159/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2729 - g_loss: 3.2318 - kl_divergence: 0.3468
Epoch 160/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2733 - g_loss: 3.2283 - kl_divergence: 0.3461
Epoch 161/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2734 - g_loss: 3.2335 - kl_divergence: 0.3458
Epoch 162/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2706 - g_loss: 3.2452 - kl_divergence: 0.3446
Epoch 163/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2691 - g_loss: 3.2610 - kl_divergence: 0.3454
Epoch 164/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2688 - g_loss: 3.2782 - kl_divergence: 0.3470
Epoch 165/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2725 - g_loss: 3.2630 - kl_divergence: 0.3488
Epoch 166/200
1562/1562 [==============================] - 57s 37ms/step - d_loss: 0.2720 - g_loss: 3.2585 - kl_divergence: 0.3461
Epoch 167/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2697 - g_loss: 3.3011 - kl_divergence: 0.3495
Epoch 168/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2679 - g_loss: 3.2953 - kl_divergence: 0.3490
Epoch 169/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2714 - g_loss: 3.2939 - kl_divergence: 0.3484
Epoch 170/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2682 - g_loss: 3.3109 - kl_divergence: 0.3518
Epoch 171/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2687 - g_loss: 3.3092 - kl_divergence: 0.3532
Epoch 172/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2688 - g_loss: 3.3288 - kl_divergence: 0.3484
Epoch 173/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2662 - g_loss: 3.3398 - kl_divergence: 0.3483
Epoch 174/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2673 - g_loss: 3.3607 - kl_divergence: 0.3484
Epoch 175/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2629 - g_loss: 3.3719 - kl_divergence: 0.3501
Epoch 176/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2671 - g_loss: 3.3627 - kl_divergence: 0.3453
Epoch 177/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2626 - g_loss: 3.4028 - kl_divergence: 0.3446
Epoch 178/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2621 - g_loss: 3.4167 - kl_divergence: 0.3461
Epoch 179/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2632 - g_loss: 3.3936 - kl_divergence: 0.3508
Epoch 180/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2650 - g_loss: 3.3966 - kl_divergence: 0.3462
Epoch 181/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2631 - g_loss: 3.4136 - kl_divergence: 0.3490
Epoch 182/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2567 - g_loss: 3.4415 - kl_divergence: 0.3490
Epoch 183/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2609 - g_loss: 3.4225 - kl_divergence: 0.3505
Epoch 184/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2638 - g_loss: 3.4268 - kl_divergence: 0.3471
Epoch 185/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2574 - g_loss: 3.4730 - kl_divergence: 0.3484
Epoch 186/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2633 - g_loss: 3.4415 - kl_divergence: 0.3492
Epoch 187/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2632 - g_loss: 3.4339 - kl_divergence: 0.3469
Epoch 188/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2581 - g_loss: 3.4664 - kl_divergence: 0.3471
Epoch 189/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2596 - g_loss: 3.4776 - kl_divergence: 0.3480
Epoch 190/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2572 - g_loss: 3.4819 - kl_divergence: 0.3475
Epoch 191/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2612 - g_loss: 3.4703 - kl_divergence: 0.3525
Epoch 192/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2571 - g_loss: 3.4966 - kl_divergence: 0.3492
Epoch 193/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 0.2538 - g_loss: 3.5485 - kl_divergence: 0.3495
Epoch 194/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2518 - g_loss: 3.5523 - kl_divergence: 0.3489
Epoch 195/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2567 - g_loss: 3.5505 - kl_divergence: 0.3480
Epoch 196/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2543 - g_loss: 3.5229 - kl_divergence: 0.3476
Epoch 197/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2543 - g_loss: 3.5500 - kl_divergence: 0.3482
Epoch 198/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2533 - g_loss: 3.5558 - kl_divergence: 0.3512
Epoch 199/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2536 - g_loss: 3.5719 - kl_divergence: 0.3525
Epoch 200/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 0.2524 - g_loss: 3.5922 - kl_divergence: 0.3531
Out[ ]:
<keras.callbacks.History at 0x22a08b19790>
In [ ]:
input_folder = './output/models/cgan/generated/'  # Replace with your frames directory
output_file = 'cgan_output_video.mp4'      # Replace with your desired output file path
gnnf.create_video_from_frames(input_folder, output_file)

Wasserstein GAN (wGAN)

WGAN is an advanced type of Generative Adversarial Network (GAN) architecture that addresses some of the training instability issues faced by standard GANs, which uses a different loss function and training approach compared to the original GAN formulation.

WGANs offer several advantages, including:

  • More stable training: Less prone to vanishing/exploding gradients, leading to smoother training and better convergence.
  • Theoretically sound loss function: The Wasserstein distance provides a clearer interpretation of the distance between real and generated data distributions.
  • Can potentially generate higher-quality images: Depending on the specific WGAN implementation and dataset, it can produce more realistic and diverse images.

Like all GANs, WGANs have two main components:

  • Generator: Creates new data samples (e.g., images) that mimic the real data distribution.
  • Critic (discriminator): Evaluates how well the generated data resembles real data.

Unlike standard GANs, the critic in WGANs doesn't directly classify samples as real or fake. Instead, it estimates the "distance" between the real and generated data distributions using the Wasserstein distance.

The generator and critic are trained in an adversarial way:

The generator aims to minimize the Wasserstein distance, effectively fooling the critic into believing its generated data is real. The critic aims to maximize the Wasserstein distance, accurately distinguishing real and generated data. Over training, the generator gets better at creating realistic data, while the critic improves its ability to discriminate. This competition leads to improved quality and diversity in the generated data.

The formula for Wasserstein loss is as follows: \begin{align*} &Wasserstein \, Loss = \frac{1}{N} \sum_{i=1}^{N} \left[ y_i \cdot f(x_i) - f(g(z_i)) \right] \\ &\text{where}\\ &N \text{ is the number of observations,} \\ &y_i \text{ is the label indicating real (1) or generated (-1) for the } i^{\text{th}} \text{ observation,} \\ &f \text{ is the critic (or discriminator) network's output,} \\ &x_i \text{ is the real data instance,} \\ &g(z_i) \text{ is the generated data instance from noise vector } z_i. \end{align*}

In [ ]:
def critic_loss(real_output, fake_output):
    return tf.reduce_mean(fake_output) - tf.reduce_mean(real_output)

def generator_loss(fake_output):
    return -tf.reduce_mean(fake_output)

from tensorflow.keras.constraints import Constraint

class ClipConstraint(Constraint):
    def __init__(self, clip_value):
        self.clip_value = clip_value

    def __call__(self, weights):
        return tf.clip_by_value(weights, -self.clip_value, self.clip_value)
In [ ]:
num_classes = 10
class wGAN(GAN_template):
    def __init__(self, latent_dim):
        super().__init__(latent_dim)
        self.num_classes = num_classes
        self.CRITIC_UPDATES = 5

    def define_discriminator(self, in_shape=(32,32,3)):
        # Image input
        constraint = ClipConstraint(0.01)
        image_input = Input(shape=in_shape)

        # Label input and embedding
        label_input = Input(shape=(1,))
        label_embedding = Embedding(num_classes, np.prod(in_shape))(label_input)
        label_embedding = Dense(np.prod(in_shape))(label_embedding)
        label_embedding = Reshape(in_shape)(label_embedding)

        # Concatenate image and label
        concatenated = Concatenate()([image_input, label_embedding])

        # Discriminator model
        x = Conv2D(64, (3,3), padding='same', kernel_constraint=constraint)(concatenated)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(64, (3,3), strides=(2,2), padding='same', kernel_constraint=constraint)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(64, (3,3), strides=(2,2), padding='same', kernel_constraint=constraint)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(64, (3,3), strides=(2,2), padding='same', kernel_constraint=constraint)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Flatten()(x)
        x = Dropout(0.4)(x)
        output = Dense(1)(x)

        # Define and compile model
        model = Model(inputs=[image_input, label_input], outputs=output)
        model.compile(loss='binary_crossentropy', optimizer=Adam(lr=0.0002, beta_1=0.5), metrics=['accuracy'])
        return model

    def define_generator(self, latent_dim):
        model = Sequential()
            # Model for processing the labels
        label_input = tf.keras.Input(shape=(1,), dtype='int32')
        label_embedding = Embedding(num_classes, latent_dim)(label_input)
        label_embedding = Flatten()(label_embedding)

        # Model for processing the latent vector
        latent_input = tf.keras.Input(shape=(latent_dim,))

        # Combine label and latent inputs
        merged_input = Concatenate()([latent_input, label_embedding])

        # Sequential model for the generator
        generator = Sequential([
            Dense(8 * 8 * 256, input_shape=(latent_dim * 2,)),
            LeakyReLU(alpha=0.2),
            Reshape((8, 8, 256)),
            Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(64, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(3, (3, 3), activation='tanh', padding='same')
        ])

        # Pass the merged input through the generator model
        generated_image = generator(merged_input)

        # Final cGAN generator model
        model = Model(inputs=[latent_input, label_input], outputs=generated_image)
        return model

    def train_step(self, data):
        # Unpack the data
        if isinstance(data, tuple):
            real_images, real_labels = data
        else:
            real_images = data
            real_labels = tf.random.uniform([tf.shape(real_images)[0]], minval=0, maxval=self.num_classes, dtype=tf.int32)      
        batch_size = tf.shape(real_images)[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        fake_labels = tf.random.uniform([batch_size], minval=0, maxval=self.num_classes, dtype=tf.int32)        
        # Critic updates
        for _ in range(self.CRITIC_UPDATES):
            with tf.GradientTape() as tape:
                generated_images = self.generator([random_latent_vectors, fake_labels], training=True)
                real_output = self.discriminator([real_images, real_labels], training=True)
                fake_output = self.discriminator([generated_images, fake_labels], training=True)
                c_loss = critic_loss(real_output, fake_output)      
            c_grads = tape.gradient(c_loss, self.discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(zip(c_grads, self.discriminator.trainable_weights))

        with tf.GradientTape() as tape:
            generated_images = self.generator([random_latent_vectors, fake_labels], training=True)
            fake_output = self.discriminator([generated_images, fake_labels], training=True)
            g_loss = generator_loss(fake_output)  # Ensure this is a suitable loss for WGAN
            kl_loss = self.kl_divergence(real_images, generated_images)
        g_grads = tape.gradient(g_loss, self.generator.trainable_weights)  # Include KL divergence in gradients
        self.g_optimizer.apply_gradients(zip(g_grads, self.generator.trainable_weights))        
        # Update metrics
        self.d_loss_tracker.update_state(c_loss)
        self.g_loss_tracker.update_state(g_loss)
        self.kl_divergence_tracker.update_state(kl_loss)        

        return{
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
            "kl_divergence": self.kl_divergence_tracker.result()
            }

    @staticmethod
    def save_plot(examples, epoch, d_losses, g_losses, kl_div, filepath):
        fig = plt.figure(figsize=(20, 15))
        gs = fig.add_gridspec(10, 10, height_ratios=[1]*10, width_ratios=[1]*10, hspace=0.25, wspace=0.2)
        examples = (examples + 1) / 2.0
        class_names = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

        for i in range(10*5):  # 3 images per class, 10 classes
            class_index = i // 5  # Determine class based on order
            ax = fig.add_subplot(gs[i % 5, class_index])
            # print(i % 5, class_index)
            ax.axis('off')
            ax.imshow(examples[i])
            # Add class label text for the first image of each class
            if i % 5 == 0:
                ax.set_title(class_names[class_index], fontsize=8)

        # Plot for discriminator losses
        ax_loss = fig.add_subplot(gs[5:8, 0:3])
        ax_loss.plot(d_losses, label="Discriminator Loss")
        ax_loss.set_title("Discriminator Loss")
        # Plot for generator losses
        ax_g_loss = fig.add_subplot(gs[5:8, 3:7])
        ax_g_loss.plot(g_losses, label="Generator Loss")
        ax_g_loss.set_title("Generator Loss")

        ax_kl_div = fig.add_subplot(gs[5:8, 7:10])
        ax_kl_div.plot(kl_div, label="KL Divergence")
        ax_kl_div.set_title("KL Divergence")

        plt.suptitle(f"Epoch {epoch+1}", fontsize=18, y=0.95)
        plt.tight_layout()
        plt.savefig(f"{filepath}generated/generated_plot_e{epoch+1}.png", bbox_inches='tight')
        plt.close()

    @staticmethod
    def generate_fake_samples(self, generator, n_samples=5, latent_dim=100):
        X, y = [], []
        for class_label in range(10):  # CIFAR-10 has 10 classes
            # Generate latent points
            x_input = np.random.randn(latent_dim * n_samples)
            x_input = x_input.reshape(n_samples, latent_dim)
            # Create class labels
            labels = np.full((n_samples, 1), class_label)
            # Generate images
            images = generator.predict([x_input, labels], verbose=0)
            X.extend(images)
            y.extend(labels)
        return np.asarray(X), np.asarray(y)

wgan = wGAN(latent_dim=100)
wgan.compile(
    d_optimizer=RMSprop(learning_rate=0.0003),
    g_optimizer=RMSprop(learning_rate=0.0003),
    loss_fn=BinaryCrossentropy(from_logits=True),
)
wgan_callback = CustomCallback(d_losses = wgan.d_loss_list, g_losses = wgan.g_loss_list, kl_div=wgan.kl_div_list, model = wgan, filepath = "output/models/wgan/")
wgan.fit(X_train, epochs = 50, callbacks = [wgan_callback])
Epoch 1/50
1562/1562 [==============================] - 151s 92ms/step - d_loss: -2511.8191 - g_loss: 11509.7070 - kl_divergence: 0.5411
Epoch 2/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -38174.3750 - g_loss: 75383.7422 - kl_divergence: 0.5661
Epoch 3/50
1562/1562 [==============================] - 143s 91ms/step - d_loss: -53997.6914 - g_loss: 155898.1562 - kl_divergence: 0.5868
Epoch 4/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -174930.6094 - g_loss: -50041.0430 - kl_divergence: 0.6316
Epoch 5/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -541265.0625 - g_loss: -2412482.0000 - kl_divergence: 0.5908
Epoch 6/50
1562/1562 [==============================] - 143s 92ms/step - d_loss: -730157.8125 - g_loss: 77811.3906 - kl_divergence: 0.5558
Epoch 7/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -1167884.8750 - g_loss: -5160780.0000 - kl_divergence: 0.5474
Epoch 8/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -1581006.0000 - g_loss: 1626707.7500 - kl_divergence: 0.5673
Epoch 9/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -2397018.2500 - g_loss: -2872515.0000 - kl_divergence: 0.5713
Epoch 10/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -2372221.0000 - g_loss: -9411951.0000 - kl_divergence: 0.5423
Epoch 11/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -3415741.5000 - g_loss: -6726378.0000 - kl_divergence: 0.5977
Epoch 12/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -4656259.0000 - g_loss: -18112762.0000 - kl_divergence: 0.5754
Epoch 13/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -7106436.0000 - g_loss: 32966476.0000 - kl_divergence: 0.5420
Epoch 14/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -11302557.0000 - g_loss: 61351872.0000 - kl_divergence: 0.5448
Epoch 15/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -17009860.0000 - g_loss: 67873368.0000 - kl_divergence: 0.4710
Epoch 16/50
1562/1562 [==============================] - 143s 91ms/step - d_loss: -11430128.0000 - g_loss: 26108232.0000 - kl_divergence: 0.4775
Epoch 17/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -20488420.0000 - g_loss: -12855508.0000 - kl_divergence: 0.5432
Epoch 18/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -127779392.0000 - g_loss: -469266592.0000 - kl_divergence: 0.5487
Epoch 19/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -120902504.0000 - g_loss: -562537280.0000 - kl_divergence: 0.5074
Epoch 20/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -54040672.0000 - g_loss: -150013616.0000 - kl_divergence: 0.4903
Epoch 21/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -31215888.0000 - g_loss: -30920370.0000 - kl_divergence: 0.4663
Epoch 22/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -133609240.0000 - g_loss: -782933440.0000 - kl_divergence: 0.4160
Epoch 23/50
1562/1562 [==============================] - 143s 91ms/step - d_loss: -188014512.0000 - g_loss: -852295680.0000 - kl_divergence: 0.4265
Epoch 24/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -139411168.0000 - g_loss: -618659520.0000 - kl_divergence: 0.4141
Epoch 25/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -48559948.0000 - g_loss: -223021840.0000 - kl_divergence: 0.4052
Epoch 26/50
1562/1562 [==============================] - 142s 91ms/step - d_loss: -68058672.0000 - g_loss: 60360144.0000 - kl_divergence: 0.4080
Epoch 27/50
1562/1562 [==============================] - 140s 89ms/step - d_loss: -302390752.0000 - g_loss: -1766296064.0000 - kl_divergence: 0.4210
Epoch 28/50
1562/1562 [==============================] - 140s 90ms/step - d_loss: -216782304.0000 - g_loss: 1197224320.0000 - kl_divergence: 0.4133
Epoch 29/50
1562/1562 [==============================] - 140s 89ms/step - d_loss: -298702592.0000 - g_loss: 1927712000.0000 - kl_divergence: 0.4263
Epoch 30/50
1562/1562 [==============================] - 140s 90ms/step - d_loss: -372681088.0000 - g_loss: 2415065088.0000 - kl_divergence: 0.4424
Epoch 31/50
1562/1562 [==============================] - 141s 90ms/step - d_loss: -416613312.0000 - g_loss: 2832033280.0000 - kl_divergence: 0.4723
Epoch 32/50
1562/1562 [==============================] - 140s 89ms/step - d_loss: -461932640.0000 - g_loss: 159228352.0000 - kl_divergence: 0.4745
Epoch 33/50
1562/1562 [==============================] - 140s 90ms/step - d_loss: -571692096.0000 - g_loss: -3261810176.0000 - kl_divergence: 0.4536
Epoch 34/50
1562/1562 [==============================] - 148s 94ms/step - d_loss: -488175552.0000 - g_loss: -2884788224.0000 - kl_divergence: 0.4420
Epoch 35/50
1562/1562 [==============================] - 152s 97ms/step - d_loss: -768050432.0000 - g_loss: -3712435456.0000 - kl_divergence: 0.4111
Epoch 36/50
1562/1562 [==============================] - 152s 97ms/step - d_loss: -858683200.0000 - g_loss: -4325497344.0000 - kl_divergence: 0.3874
Epoch 37/50
1562/1562 [==============================] - 152s 97ms/step - d_loss: -818794432.0000 - g_loss: -4564397056.0000 - kl_divergence: 0.3683
Epoch 38/50
1562/1562 [==============================] - 152s 97ms/step - d_loss: -169840064.0000 - g_loss: -301608896.0000 - kl_divergence: 0.3777
Epoch 39/50
1562/1562 [==============================] - 153s 98ms/step - d_loss: -201049216.0000 - g_loss: -12258592.0000 - kl_divergence: 0.3465
Epoch 40/50
1562/1562 [==============================] - 152s 97ms/step - d_loss: -197536288.0000 - g_loss: 1037425536.0000 - kl_divergence: 0.3395
Epoch 41/50
1562/1562 [==============================] - 154s 99ms/step - d_loss: -211085920.0000 - g_loss: 39040424.0000 - kl_divergence: 0.3374
Epoch 42/50
1562/1562 [==============================] - 152s 97ms/step - d_loss: -204623872.0000 - g_loss: -304704864.0000 - kl_divergence: 0.3552
Epoch 43/50
1562/1562 [==============================] - 152s 97ms/step - d_loss: -295322240.0000 - g_loss: -1090108672.0000 - kl_divergence: 0.3444
Epoch 44/50
1562/1562 [==============================] - 153s 98ms/step - d_loss: -284562528.0000 - g_loss: -893282368.0000 - kl_divergence: 0.3511
Epoch 45/50
1562/1562 [==============================] - 150s 96ms/step - d_loss: -346499520.0000 - g_loss: -679465408.0000 - kl_divergence: 0.3725
Epoch 46/50
1562/1562 [==============================] - 152s 97ms/step - d_loss: -1048624512.0000 - g_loss: 4641689088.0000 - kl_divergence: 0.3748
Epoch 47/50
1562/1562 [==============================] - 152s 97ms/step - d_loss: -1603169536.0000 - g_loss: 8545432576.0000 - kl_divergence: 0.3751
Epoch 48/50
1562/1562 [==============================] - 152s 97ms/step - d_loss: -376864448.0000 - g_loss: 88227160.0000 - kl_divergence: 0.3829
Epoch 49/50
1562/1562 [==============================] - 152s 97ms/step - d_loss: -440872544.0000 - g_loss: -1773806464.0000 - kl_divergence: 0.3780
Epoch 50/50
1562/1562 [==============================] - 153s 98ms/step - d_loss: -1162849024.0000 - g_loss: 7700485632.0000 - kl_divergence: 0.3899
Out[ ]:
<keras.callbacks.History at 0x1a18e4d76a0>

However, this implementation of wGAN seems to be ineffective, with the model crashing even after serveral attempts to tune it properly. Hence, we shall state that for the given architecture, using wGAN is ineffective, and does not work.


Hinge GAN

Earlier on, we attempted (but failed) to use wGAN to solve the task at hand. Now, we shall try an alternative model to try to solve the given task. for this, we shall use Hinge GAN.

Hinge GANs are a type of Generative Adversarial Network (GAN) that utilize a "hinge loss" function to train their discriminator, the component responsible for judging real versus fake data. Unlike the standard binary cross-entropy loss in traditional GANs, hinge loss only penalizes the discriminator when it makes mistakes or fails to confidently distinguish between real and generated data. This approach offers several benefits:

  1. Improved Stability: Hinge loss avoids the vanishing gradient problem that can plague traditional GANs during training. This leads to smoother training and potentially faster convergence to stable models.

  2. Better Focus on Margins: By focusing on pushing real and fake data apart in the scoring space, hinge loss encourages the discriminator to pay more attention to the quality of generated data rather than just classifying them correctly. This can lead to sharper and more realistic generations.

However, hinge loss also comes with drawbacks. Its emphasis on margins can sometimes lead to "mode collapse," where the generator gets stuck producing only a limited variety of outputs. Carefully choosing hyperparameters and architectures can help mitigate this issue.

Thus, we use this model to hopefully stabilize the model, as well as to generate higher quality, realistic, and detailed data.

The formula for hinge loss is as follows:

\begin{align*} &D(x) = \max(0, 1 - t \cdot f(x)) \\ &\text{where} \\ &t = \begin{cases} 1 & \text{for real data (positive class)} \\ -1 & \text{for generated data (negative class)} \end{cases} \\ &f(x) \text{ is the discriminator's output for an input } x. \end{align*}

While in a traditional Hinge GAN, one might use both hinge loss for both the generator and discriminator, it may not yield the best results for the network. This is as Hinge loss, as its name suggests, focuses on pushing real and fake data apart in the scoring space. While this is beneficial for the discriminator to learn to tell them apart, it's not directly relevant to the generator's goal of creating realistic data. Penalizing the generator based on the discriminator's confidence, even for good outputs, can hinder its learning and potentially lead to suboptimal results.

In [ ]:
num_classes = 10
class hingeGAN(GAN_template):
    def __init__(self, latent_dim):
        super().__init__(latent_dim)
        self.num_classes = num_classes

    def define_discriminator(self, in_shape=(32,32,3)):
        # Image input
        image_input = Input(shape=in_shape)

        # Label input and embedding
        label_input = Input(shape=(1,))
        label_embedding = Embedding(num_classes, np.prod(in_shape))(label_input)
        label_embedding = Dense(np.prod(in_shape))(label_embedding)
        label_embedding = Reshape(in_shape)(label_embedding)

        # Concatenate image and label
        concatenated = Concatenate()([image_input, label_embedding])

        # Discriminator model
        x = Conv2D(64, (3,3), padding='same')(concatenated)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(64, (3,3), strides=(2,2), padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(64, (3,3), strides=(2,2), padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Conv2D(128, (3,3), strides=(2,2), padding='same', kernel_regularizer=l1_l2(l1=0.001, l2=0.001))(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Flatten()(x)
        x = Dropout(0.4)(x)
        output = Dense(1, activation='linear')(x)

        # Define and compile model
        model = Model(inputs=[image_input, label_input], outputs=output)
        return model

    def define_generator(self, latent_dim):
        model = Sequential()
            # Model for processing the labels
        label_input = tf.keras.Input(shape=(1,), dtype='int32')
        label_embedding = Embedding(num_classes, latent_dim)(label_input)
        label_embedding = Flatten()(label_embedding)

        # Model for processing the latent vector
        latent_input = tf.keras.Input(shape=(latent_dim,))

        # Combine label and latent inputs
        merged_input = Concatenate()([latent_input, label_embedding])

        # Sequential model for the generator
        generator = Sequential([
            Dense(8 * 8 * 256, input_shape=(latent_dim * 2,)),
            LeakyReLU(alpha=0.2),
            Reshape((8, 8, 256)),
            Conv2DTranspose(256, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(128, (4, 4), strides=(2, 2), padding='same'),
            BatchNormalization(),
            LeakyReLU(alpha=0.2),
            Conv2DTranspose(3, (3, 3), activation='tanh', padding='same')
        ])

        # Pass the merged input through the generator model
        generated_image = generator(merged_input)

        # Final cGAN generator model
        model = Model(inputs=[latent_input, label_input], outputs=generated_image)
        return model

    def train_step(self, data):
        # Unpack the data. Its structure depends on your dataset and
        # whether it includes labels
        if isinstance(data, tuple):
               real_images, real_labels = data
        else:
            real_images = data
            real_labels = tf.random.uniform([tf.shape(real_images)[0]], minval=0, maxval=self.num_classes, dtype=tf.int32)

        batch_size = real_images[0].shape[0]
        random_latent_vectors = tf.random.normal(shape=(batch_size, self.latent_dim))
        
        # real_images = (real_images - 127.5) / 127.5  # Normalize to [-1, 1] if your real_images are in [0, 255]
        real_images = tf.reshape(real_images, [batch_size, 32, 32, 3])

        # Generate labels for fake images if needed
        fake_labels = tf.random.uniform([batch_size], minval=0, maxval=self.num_classes, dtype=tf.int32)

        # Generate fake images
        generated_images = self.generator([random_latent_vectors, fake_labels])
        combined_images = tf.concat([generated_images, tf.cast(real_images, tf.float32)], axis=0)

        real_labels = tf.squeeze(real_labels)
        combined_labels = tf.concat([tf.cast(fake_labels, 'uint8'), real_labels], axis=0)  # Concatenate labels as well


        # Labels for discriminator to discriminate real from fake images
        discriminator_labels = tf.concat(
            [tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], axis=0
        )

        # Train the discriminator
        with tf.GradientTape() as tape:
            predictions_on_real = self.discriminator([real_images, real_labels])
            predictions_on_fake = self.discriminator([generated_images, fake_labels])

            # Hinge loss for the discriminator
            d_loss_real = tf.reduce_mean(tf.nn.relu(1.0 - predictions_on_real))
            d_loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + predictions_on_fake))
            d_loss = d_loss_real + d_loss_fake
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(zip(grads, self.discriminator.trainable_weights))

        # Misleading labels for the generator
        misleading_labels = tf.ones((batch_size, 1))

        # Train the generator
        with tf.GradientTape() as tape:
            generated_images = self.generator([random_latent_vectors, fake_labels])
            predictions = self.discriminator([generated_images, fake_labels])
            g_loss = -tf.reduce_mean(predictions)
            kl_loss = self.kl_divergence(real_images, generated_images)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        # Update metrics
        self.d_loss_tracker.update_state(d_loss)
        self.g_loss_tracker.update_state(g_loss)
        self.kl_divergence_tracker.update_state(kl_loss)

        return {
            "d_loss": self.d_loss_tracker.result(),
            "g_loss": self.g_loss_tracker.result(),
            "kl_divergence": self.kl_divergence_tracker.result()
        }
    @staticmethod
    def save_plot(examples, epoch, d_losses, g_losses, kl_div, filepath):
        fig = plt.figure(figsize=(20, 15))
        gs = fig.add_gridspec(10, 10, height_ratios=[1]*10, width_ratios=[1]*10, hspace=0.25, wspace=0.2)
        examples = (examples + 1) / 2.0
        class_names = ['Airplane', 'Automobile', 'Bird', 'Cat', 'Deer', 'Dog', 'Frog', 'Horse', 'Ship', 'Truck']

        for i in range(10*5):  # 3 images per class, 10 classes
            class_index = i // 5  # Determine class based on order
            ax = fig.add_subplot(gs[i % 5, class_index])
            # print(i % 5, class_index)
            ax.axis('off')
            ax.imshow(examples[i])
            # Add class label text for the first image of each class
            if i % 5 == 0:
                ax.set_title(class_names[class_index], fontsize=8)

        # Plot for discriminator losses
        ax_loss = fig.add_subplot(gs[5:8, 0:3])
        ax_loss.plot(d_losses, label="Discriminator Loss")
        ax_loss.set_title("Discriminator Loss")
        # Plot for generator losses
        ax_g_loss = fig.add_subplot(gs[5:8, 3:7])
        ax_g_loss.plot(g_losses, label="Generator Loss")
        ax_g_loss.set_title("Generator Loss")

        ax_kl_div = fig.add_subplot(gs[5:8, 7:10])
        ax_kl_div.plot(kl_div, label="KL Divergence")
        ax_kl_div.set_title("KL Divergence")

        plt.suptitle(f"Epoch {epoch+1}", fontsize=18, y=0.95)
        plt.tight_layout()
        plt.savefig(f"{filepath}generated/generated_plot_e{epoch+1}.png", bbox_inches='tight')
        plt.close()

    @staticmethod
    def generate_fake_samples(self, generator, n_samples=5, latent_dim=100):
        X, y = [], []
        for class_label in range(10):  # CIFAR-10 has 10 classes
            # Generate latent points
            x_input = np.random.randn(latent_dim * n_samples)
            x_input = x_input.reshape(n_samples, latent_dim)
            # Create class labels
            labels = np.full((n_samples, 1), class_label)
            # Generate images
            images = generator.predict([x_input, labels], verbose=0)
            X.extend(images)
            y.extend(labels)
        return np.asarray(X), np.asarray(y)

hinge_gan = hingeGAN(latent_dim=100)
hinge_gan.compile(
    d_optimizer=Adam(learning_rate=0.0003),
    g_optimizer=Adam(learning_rate=0.0003),
    loss_fn=BinaryCrossentropy(),
)
hinge_gan_callback = CustomCallback(d_losses = hinge_gan.d_loss_list, g_losses = hinge_gan.g_loss_list, kl_div=hinge_gan.kl_div_list, model = hinge_gan, filepath = "output/models/hinge_gan/")
hinge_gan.fit(X_train, epochs = 200, callbacks = [hinge_gan_callback])
Epoch 1/200
1562/1562 [==============================] - 52s 32ms/step - d_loss: 1.4068 - g_loss: 1.1703 - kl_divergence: 0.7009
Epoch 2/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.2997 - g_loss: 1.1075 - kl_divergence: 0.5582
Epoch 3/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.5804 - g_loss: 0.7518 - kl_divergence: 0.5271
Epoch 4/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.4858 - g_loss: 0.8374 - kl_divergence: 0.4826
Epoch 5/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.3063 - g_loss: 1.1355 - kl_divergence: 0.5613
Epoch 6/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.5198 - g_loss: 0.8947 - kl_divergence: 0.4339
Epoch 7/200
1562/1562 [==============================] - 54s 34ms/step - d_loss: 1.4658 - g_loss: 0.8365 - kl_divergence: 0.4356
Epoch 8/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.6505 - g_loss: 0.5848 - kl_divergence: 0.3753
Epoch 9/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.6594 - g_loss: 0.6821 - kl_divergence: 0.3933
Epoch 10/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.6946 - g_loss: 0.5013 - kl_divergence: 0.3628
Epoch 11/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.7023 - g_loss: 0.5359 - kl_divergence: 0.3619
Epoch 12/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.6076 - g_loss: 0.7303 - kl_divergence: 0.3474
Epoch 13/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.5535 - g_loss: 0.7393 - kl_divergence: 0.3444
Epoch 14/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.4983 - g_loss: 0.7002 - kl_divergence: 0.3393
Epoch 15/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.5308 - g_loss: 0.6908 - kl_divergence: 0.3411
Epoch 16/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.5017 - g_loss: 0.6909 - kl_divergence: 0.3418
Epoch 17/200
1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.3154 - g_loss: 0.8463 - kl_divergence: 0.3460
Epoch 18/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.1563 - g_loss: 1.0569 - kl_divergence: 0.3702
Epoch 19/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.1680 - g_loss: 1.0744 - kl_divergence: 0.3824
Epoch 20/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.1684 - g_loss: 1.0503 - kl_divergence: 0.3869
Epoch 21/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.2071 - g_loss: 0.9494 - kl_divergence: 0.4044
Epoch 22/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 1.2979 - g_loss: 0.8252 - kl_divergence: 0.3918
Epoch 23/200
1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.3386 - g_loss: 0.7789 - kl_divergence: 0.3829
Epoch 24/200
1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.3664 - g_loss: 0.7522 - kl_divergence: 0.3735
Epoch 25/200
1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.3652 - g_loss: 0.7219 - kl_divergence: 0.3888
Epoch 26/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.3442 - g_loss: 0.7601 - kl_divergence: 0.3748
Epoch 27/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.3432 - g_loss: 0.7415 - kl_divergence: 0.3680
Epoch 28/200
1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.3300 - g_loss: 0.7771 - kl_divergence: 0.3579
Epoch 29/200
1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.2607 - g_loss: 0.8332 - kl_divergence: 0.3576
Epoch 30/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 1.2657 - g_loss: 0.8211 - kl_divergence: 0.3631
Epoch 31/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.2001 - g_loss: 0.8541 - kl_divergence: 0.3567
Epoch 32/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1823 - g_loss: 0.8885 - kl_divergence: 0.3626
Epoch 33/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1570 - g_loss: 0.8930 - kl_divergence: 0.3535
Epoch 34/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1634 - g_loss: 0.8873 - kl_divergence: 0.3482
Epoch 35/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1483 - g_loss: 0.8892 - kl_divergence: 0.3487
Epoch 36/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1788 - g_loss: 0.8548 - kl_divergence: 0.3495
Epoch 37/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1767 - g_loss: 0.8683 - kl_divergence: 0.3467
Epoch 38/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2036 - g_loss: 0.8460 - kl_divergence: 0.3521
Epoch 39/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2442 - g_loss: 0.8155 - kl_divergence: 0.3522
Epoch 40/200
1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.2047 - g_loss: 0.8271 - kl_divergence: 0.3573
Epoch 41/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2177 - g_loss: 0.8495 - kl_divergence: 0.3535
Epoch 42/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2080 - g_loss: 0.8276 - kl_divergence: 0.3514
Epoch 43/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2192 - g_loss: 0.8659 - kl_divergence: 0.3538
Epoch 44/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2127 - g_loss: 0.8250 - kl_divergence: 0.3563
Epoch 45/200
1562/1562 [==============================] - 49s 32ms/step - d_loss: 1.2102 - g_loss: 0.8466 - kl_divergence: 0.3498
Epoch 46/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2119 - g_loss: 0.8302 - kl_divergence: 0.3522
Epoch 47/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2072 - g_loss: 0.8404 - kl_divergence: 0.3503
Epoch 48/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2030 - g_loss: 0.8301 - kl_divergence: 0.3504
Epoch 49/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2062 - g_loss: 0.8528 - kl_divergence: 0.3473
Epoch 50/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1961 - g_loss: 0.8442 - kl_divergence: 0.3496
Epoch 51/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2025 - g_loss: 0.8651 - kl_divergence: 0.3466
Epoch 52/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.1947 - g_loss: 0.8362 - kl_divergence: 0.3529
Epoch 53/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1871 - g_loss: 0.8579 - kl_divergence: 0.3443
Epoch 54/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1739 - g_loss: 0.8694 - kl_divergence: 0.3435
Epoch 55/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.2025 - g_loss: 0.8642 - kl_divergence: 0.3440
Epoch 56/200
1562/1562 [==============================] - 49s 31ms/step - d_loss: 1.1924 - g_loss: 0.8862 - kl_divergence: 0.3461
Epoch 57/200
1562/1562 [==============================] - 51s 33ms/step - d_loss: 1.1943 - g_loss: 0.8427 - kl_divergence: 0.3479
Epoch 58/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1770 - g_loss: 0.8710 - kl_divergence: 0.3549
Epoch 59/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1782 - g_loss: 0.8660 - kl_divergence: 0.3445
Epoch 60/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1712 - g_loss: 0.8715 - kl_divergence: 0.3423
Epoch 61/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1678 - g_loss: 0.8867 - kl_divergence: 0.3435
Epoch 62/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1596 - g_loss: 0.8751 - kl_divergence: 0.3466
Epoch 63/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1657 - g_loss: 0.8960 - kl_divergence: 0.3444
Epoch 64/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1642 - g_loss: 0.8928 - kl_divergence: 0.3463
Epoch 65/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1646 - g_loss: 0.8916 - kl_divergence: 0.3468
Epoch 66/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1600 - g_loss: 0.8949 - kl_divergence: 0.3432
Epoch 67/200
1562/1562 [==============================] - 52s 33ms/step - d_loss: 1.1828 - g_loss: 0.9018 - kl_divergence: 0.3432
Epoch 68/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1559 - g_loss: 0.8802 - kl_divergence: 0.3491
Epoch 69/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1576 - g_loss: 0.9157 - kl_divergence: 0.3448
Epoch 70/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1659 - g_loss: 0.8853 - kl_divergence: 0.3429
Epoch 71/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1595 - g_loss: 0.9050 - kl_divergence: 0.3417
Epoch 72/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1586 - g_loss: 0.8963 - kl_divergence: 0.3431
Epoch 73/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1713 - g_loss: 0.9004 - kl_divergence: 0.3425
Epoch 74/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1631 - g_loss: 0.8904 - kl_divergence: 0.3463
Epoch 75/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1568 - g_loss: 0.9570 - kl_divergence: 0.3459
Epoch 76/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1544 - g_loss: 0.8804 - kl_divergence: 0.3461
Epoch 77/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1442 - g_loss: 0.9052 - kl_divergence: 0.3444
Epoch 78/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1503 - g_loss: 0.9182 - kl_divergence: 0.3437
Epoch 79/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1501 - g_loss: 0.9152 - kl_divergence: 0.3443
Epoch 80/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1457 - g_loss: 0.9098 - kl_divergence: 0.3465
Epoch 81/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1458 - g_loss: 0.9083 - kl_divergence: 0.3472
Epoch 82/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1446 - g_loss: 0.9175 - kl_divergence: 0.3492
Epoch 83/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1347 - g_loss: 0.9641 - kl_divergence: 0.3436
Epoch 84/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1482 - g_loss: 0.9012 - kl_divergence: 0.3442
Epoch 85/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1334 - g_loss: 0.9249 - kl_divergence: 0.3455
Epoch 86/200
1562/1562 [==============================] - 53s 34ms/step - d_loss: 1.1392 - g_loss: 0.9208 - kl_divergence: 0.3465
Epoch 87/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1310 - g_loss: 0.9467 - kl_divergence: 0.3433
Epoch 88/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1334 - g_loss: 0.9430 - kl_divergence: 0.3436
Epoch 89/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1438 - g_loss: 0.9275 - kl_divergence: 0.3511
Epoch 90/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1332 - g_loss: 0.9193 - kl_divergence: 0.3465
Epoch 91/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1261 - g_loss: 0.9217 - kl_divergence: 0.3454
Epoch 92/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1287 - g_loss: 0.9298 - kl_divergence: 0.3465
Epoch 93/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1289 - g_loss: 0.9622 - kl_divergence: 0.3455
Epoch 94/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1365 - g_loss: 0.9233 - kl_divergence: 0.3472
Epoch 95/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1166 - g_loss: 0.9496 - kl_divergence: 0.3462
Epoch 96/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1111 - g_loss: 0.9699 - kl_divergence: 0.3430
Epoch 97/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1102 - g_loss: 0.9419 - kl_divergence: 0.3431
Epoch 98/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1112 - g_loss: 0.9633 - kl_divergence: 0.3443
Epoch 99/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1169 - g_loss: 0.9709 - kl_divergence: 0.3436
Epoch 100/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1140 - g_loss: 0.9564 - kl_divergence: 0.3420
Epoch 101/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1138 - g_loss: 0.9648 - kl_divergence: 0.3471
Epoch 102/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1131 - g_loss: 0.9483 - kl_divergence: 0.3412
Epoch 103/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1061 - g_loss: 0.9601 - kl_divergence: 0.3441
Epoch 104/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1016 - g_loss: 0.9799 - kl_divergence: 0.3454
Epoch 105/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1091 - g_loss: 0.9835 - kl_divergence: 0.3436
Epoch 106/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1042 - g_loss: 0.9706 - kl_divergence: 0.3435
Epoch 107/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1091 - g_loss: 0.9762 - kl_divergence: 0.3441
Epoch 108/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1133 - g_loss: 0.9590 - kl_divergence: 0.3473
Epoch 109/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.1072 - g_loss: 0.9735 - kl_divergence: 0.3448
Epoch 110/200
1562/1562 [==============================] - 54s 34ms/step - d_loss: 1.1011 - g_loss: 0.9659 - kl_divergence: 0.3434
Epoch 111/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0891 - g_loss: 0.9771 - kl_divergence: 0.3445
Epoch 112/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0943 - g_loss: 0.9774 - kl_divergence: 0.3464
Epoch 113/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0956 - g_loss: 0.9813 - kl_divergence: 0.3424
Epoch 114/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0931 - g_loss: 0.9837 - kl_divergence: 0.3460
Epoch 115/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0891 - g_loss: 1.0101 - kl_divergence: 0.3486
Epoch 116/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0906 - g_loss: 0.9853 - kl_divergence: 0.3543
Epoch 117/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0859 - g_loss: 0.9850 - kl_divergence: 0.3429
Epoch 118/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0752 - g_loss: 1.0073 - kl_divergence: 0.3436
Epoch 119/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0840 - g_loss: 0.9904 - kl_divergence: 0.3455
Epoch 120/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0879 - g_loss: 0.9863 - kl_divergence: 0.3484
Epoch 121/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0602 - g_loss: 1.0297 - kl_divergence: 0.3452
Epoch 122/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0681 - g_loss: 0.9994 - kl_divergence: 0.3433
Epoch 123/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0790 - g_loss: 1.0140 - kl_divergence: 0.3453
Epoch 124/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0661 - g_loss: 1.0104 - kl_divergence: 0.3429
Epoch 125/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0746 - g_loss: 1.0140 - kl_divergence: 0.3458
Epoch 126/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0631 - g_loss: 1.0134 - kl_divergence: 0.3437
Epoch 127/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0578 - g_loss: 1.0399 - kl_divergence: 0.3463
Epoch 128/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0543 - g_loss: 1.0314 - kl_divergence: 0.3422
Epoch 129/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0683 - g_loss: 1.0400 - kl_divergence: 0.3437
Epoch 130/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0517 - g_loss: 1.0381 - kl_divergence: 0.3462
Epoch 131/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0576 - g_loss: 1.0286 - kl_divergence: 0.3426
Epoch 132/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0572 - g_loss: 1.0345 - kl_divergence: 0.3432
Epoch 133/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0579 - g_loss: 1.0400 - kl_divergence: 0.3436
Epoch 134/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0512 - g_loss: 1.0413 - kl_divergence: 0.3452
Epoch 135/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0485 - g_loss: 1.0525 - kl_divergence: 0.3435
Epoch 136/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0488 - g_loss: 1.0618 - kl_divergence: 0.3425
Epoch 137/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0471 - g_loss: 1.0503 - kl_divergence: 0.3521
Epoch 138/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0446 - g_loss: 1.0619 - kl_divergence: 0.3450
Epoch 139/200
1562/1562 [==============================] - 55s 35ms/step - d_loss: 1.0402 - g_loss: 1.0503 - kl_divergence: 0.3445
Epoch 140/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0456 - g_loss: 1.0398 - kl_divergence: 0.3434
Epoch 141/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0426 - g_loss: 1.0680 - kl_divergence: 0.3460
Epoch 142/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0513 - g_loss: 1.0747 - kl_divergence: 0.3453
Epoch 143/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0383 - g_loss: 1.0625 - kl_divergence: 0.3447
Epoch 144/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0350 - g_loss: 1.0663 - kl_divergence: 0.3455
Epoch 145/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0323 - g_loss: 1.0645 - kl_divergence: 0.3459
Epoch 146/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0269 - g_loss: 1.0613 - kl_divergence: 0.3414
Epoch 147/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0286 - g_loss: 1.0912 - kl_divergence: 0.3431
Epoch 148/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0291 - g_loss: 1.0562 - kl_divergence: 0.3452
Epoch 149/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0133 - g_loss: 1.0851 - kl_divergence: 0.3460
Epoch 150/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0292 - g_loss: 1.0763 - kl_divergence: 0.3473
Epoch 151/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0199 - g_loss: 1.0885 - kl_divergence: 0.3437
Epoch 152/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0239 - g_loss: 1.0846 - kl_divergence: 0.3465
Epoch 153/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0202 - g_loss: 1.0812 - kl_divergence: 0.3435
Epoch 154/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0230 - g_loss: 1.0968 - kl_divergence: 0.3474
Epoch 155/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0128 - g_loss: 1.0996 - kl_divergence: 0.3483
Epoch 156/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0180 - g_loss: 1.0773 - kl_divergence: 0.3434
Epoch 157/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0084 - g_loss: 1.0954 - kl_divergence: 0.3450
Epoch 158/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0274 - g_loss: 1.0952 - kl_divergence: 0.3480
Epoch 159/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0117 - g_loss: 1.0851 - kl_divergence: 0.3566
Epoch 160/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0025 - g_loss: 1.1046 - kl_divergence: 0.3431
Epoch 161/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0130 - g_loss: 1.1085 - kl_divergence: 0.3425
Epoch 162/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0145 - g_loss: 1.1034 - kl_divergence: 0.3457
Epoch 163/200
1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.0003 - g_loss: 1.1156 - kl_divergence: 0.3489
Epoch 164/200
1562/1562 [==============================] - 51s 32ms/step - d_loss: 0.9967 - g_loss: 1.1218 - kl_divergence: 0.3468
Epoch 165/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9979 - g_loss: 1.1281 - kl_divergence: 0.3450
Epoch 166/200
1562/1562 [==============================] - 51s 32ms/step - d_loss: 1.0051 - g_loss: 1.1207 - kl_divergence: 0.3417
Epoch 167/200
1562/1562 [==============================] - 51s 32ms/step - d_loss: 0.9906 - g_loss: 1.1343 - kl_divergence: 0.3428
Epoch 168/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9867 - g_loss: 1.1278 - kl_divergence: 0.3435
Epoch 169/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9873 - g_loss: 1.1408 - kl_divergence: 0.3460
Epoch 170/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9899 - g_loss: 1.1222 - kl_divergence: 0.3444
Epoch 171/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9945 - g_loss: 1.1174 - kl_divergence: 0.3499
Epoch 172/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 1.0026 - g_loss: 1.1190 - kl_divergence: 0.3439
Epoch 173/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9868 - g_loss: 1.1336 - kl_divergence: 0.3441
Epoch 174/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9915 - g_loss: 1.1241 - kl_divergence: 0.3441
Epoch 175/200
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.9822 - g_loss: 1.1513 - kl_divergence: 0.3437
Epoch 176/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9804 - g_loss: 1.1462 - kl_divergence: 0.3556
Epoch 177/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9786 - g_loss: 1.1407 - kl_divergence: 0.3465
Epoch 178/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9720 - g_loss: 1.1447 - kl_divergence: 0.3450
Epoch 179/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9881 - g_loss: 1.1381 - kl_divergence: 0.3457
Epoch 180/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9783 - g_loss: 1.1500 - kl_divergence: 0.3439
Epoch 181/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9749 - g_loss: 1.1415 - kl_divergence: 0.3549
Epoch 182/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9770 - g_loss: 1.1448 - kl_divergence: 0.3448
Epoch 183/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9796 - g_loss: 1.1363 - kl_divergence: 0.3487
Epoch 184/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9818 - g_loss: 1.1650 - kl_divergence: 0.3474
Epoch 185/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9675 - g_loss: 1.1617 - kl_divergence: 0.3460
Epoch 186/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9752 - g_loss: 1.1621 - kl_divergence: 0.3438
Epoch 187/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9677 - g_loss: 1.1749 - kl_divergence: 0.3468
Epoch 188/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9676 - g_loss: 1.1823 - kl_divergence: 0.3456
Epoch 189/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9656 - g_loss: 1.1680 - kl_divergence: 0.3436
Epoch 190/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9589 - g_loss: 1.1822 - kl_divergence: 0.3466
Epoch 191/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9564 - g_loss: 1.1801 - kl_divergence: 0.3467
Epoch 192/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9578 - g_loss: 1.1828 - kl_divergence: 0.3419
Epoch 193/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9534 - g_loss: 1.1911 - kl_divergence: 0.3455
Epoch 194/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9585 - g_loss: 1.1841 - kl_divergence: 0.3462
Epoch 195/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9560 - g_loss: 1.1895 - kl_divergence: 0.3441
Epoch 196/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9481 - g_loss: 1.1864 - kl_divergence: 0.3454
Epoch 197/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9547 - g_loss: 1.1838 - kl_divergence: 0.3428
Epoch 198/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9475 - g_loss: 1.1966 - kl_divergence: 0.3469
Epoch 199/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9551 - g_loss: 1.2710 - kl_divergence: 0.3457
Epoch 200/200
1562/1562 [==============================] - 50s 32ms/step - d_loss: 0.9684 - g_loss: 1.1674 - kl_divergence: 0.3464
Out[ ]:
<keras.callbacks.History at 0x1f892a588e0>

From the image generated, we can see that Hinge GAN actually performs quite well, with alot of images having some sort of identifiable object in them, while also having the detail that other models are unable to replicate.


Model Selection

Now, we shall make use of FID, as well as visual inspection to help us decide the best model. Since wGAN collapsed, it shall not be used in this comparison.

In [ ]:
dcgan_images, _ = dcgan.generate_fake_samples(dcgan, dcgan.generator, n_samples = 1000, latent_dim = 100)
dcgan_fid = calcFID(dcgan_images, num_images = 1000)
print(f"FID for DCGAN is {dcgan_fid}")
32/32 [==============================] - 22s 657ms/step
32/32 [==============================] - 21s 662ms/step
FID for DCGAN is 89.84017345710066
In [ ]:
cgan_images, _ = cgan.generate_fake_samples(cgan, cgan.generator, n_samples=100, latent_dim = 100)
cgan_fid = calcFID(cgan_images, num_images = 1000)
print(f"FID for cGAN is {cgan_fid}")
32/32 [==============================] - 26s 800ms/step
32/32 [==============================] - 25s 798ms/step
FID for cGAN is 67.36210799376948
In [ ]:
hinge_gan_images, _ = hinge_gan.generate_fake_samples(hinge_gan, hinge_gan.generator, n_samples = 100, latent_dim = 100)
hinge_gan_fid = calcFID(hinge_gan_images, num_images = 1000)
print(f"FID for Hinge GAN is {hinge_gan_fid}")
32/32 [==============================] - 5s 88ms/step
32/32 [==============================] - 3s 85ms/step
FID for Hinge GAN is 70.16999625299437

From the quantitative analysis, we can see that cGAN performs the best, however, on visual inspection, the images from Hinge GAN appear to be better and more detailed, while only performing slightly worse metric-wise. Hence, we shall use it as our final model, and improve it from here.


In [ ]:
def batch_images(images, batch_size):
    """Split the images into batches."""
    for i in range(0, len(images), batch_size):
        yield images[i:i + batch_size]

def display_images_in_grid(images, grid_size, title = None):
    """Display images in a grid."""
    fig, axs = plt.subplots(grid_size, grid_size, figsize=(15, 15))
    axs = axs.flatten()
    for img, ax in zip(images, axs):
        ax.imshow(img)
        ax.axis('off')
    if title is not None:
        plt.suptitle(title, y = 0.92)
    plt.show()

# Usage Example
all_images = hinge_gan.generate_fake_samples(hinge_gan, hinge_gan.generator, n_samples=10, latent_dim=100)[0]
all_images = (all_images + 1) / 2.0  # Scale images to [0, 1]
batches = list(batch_images(all_images, 100))

for batch in batches:
    display_images_in_grid(batch, 10)

Above are some images generated by Hinge GAN (100 per class). We can see that the images are quite realistic, and well-defined. However, there are images which are blobby/not as well defined. We shall try to optimize the model further to eliminate this.


Model Improvement

Now that we have completed the larger steps in model improvement, we shall now take the smaller, final step to allow the model to reach it's peak performance. We shall do this by hypertuning the parameters of the model. For our case, we shall only tune the optimizer, as the loss function is already very well suited for our use case. We shall tune the models using the Adam, SGD, and RMSProp optimizers.

Stochastic Gradient Descent (SGD) is one of the most basic and widely used optimization algorithms in machine learning and deep learning. It's a variant of gradient descent where instead of performing computations on the whole dataset – which can be computationally intensive for large datasets – SGD updates the model's weights using only a single or a few samples at a time. This makes the algorithm much faster and more suitable for large datasets. The formula for SGD is as follows. $$ w_{t+1} = w_t - \eta \cdot \nabla L(w_t, x_i, y_i) $$

\begin{aligned} \text{where} \\ w_{t+1} &\text{ is the updated weight vector at time } t+1, \\ w_t &\text{ is the weight vector at time } t, \\ \eta &\text{ is the learning rate, and} \\ \nabla L(w_t, x_i, y_i) &\text{ is the gradient of the loss function } L \text{ with respect to the weights } w, \text{ evaluated at a randomly chosen data point } (x_i, y_i). \end{aligned}

RMSProp, short for Root Mean Square Propagation, is an adaptive learning rate method proposed by Geoffrey Hinton. It addresses some of the limitations of SGD, especially in the context of minimizing functions in very high-dimensional spaces. RMSProp adjusts the learning rate for each weight based on the recent magnitudes of the gradients for that weight. This means that the learning rate is reduced for weights that consistently receive high gradients, which helps in faster convergence especially in situations involving oscillations. The formula for RMSProp is as follows.

\begin{aligned} v_{t+1} &= \beta \cdot v_t + (1 - \beta) \cdot (\nabla L(w_t))^2 \\ w_{t+1} &= w_t - \frac{\eta}{\sqrt{v_{t+1} + \epsilon}} \cdot \nabla L(w_t) \\ \text{where} \\ v_{t+1} &\text{ is the exponentially decaying average of squared gradients} \\ w_{t+1} &\text{ is the updated weight vector at time } t+1 \\ w_t &\text{ is the weight vector at time } t \\ \eta &\text{ is the learning rate} \\ \beta &\text{ is the decay rate, controlling the moving average of squared gradients} \\ \epsilon &\text{ is a small number to prevent division by zero} \\ \nabla L(w_t) &\text{ is the gradient of the loss function } L \text{ with respect to the weights } w \text{ at time } t \end{aligned}

Despite Adam and RMSProp being both able to adjust their learning rates, the values which they start at are still quite important to how the model performs. Hence, we shall do 3 permutations for each optimizer, and see how they improve from there.

In [ ]:
from tensorflow.keras.optimizers import SGD
from IPython.display import clear_output

optimizer_list = [
    Adam(),
    Adam(learning_rate=0.0001),
    Adam(learning_rate=0.0002),
    SGD(learning_rate=0.0005),
    SGD(learning_rate=0.05),
    SGD(),
    RMSprop(learning_rate=0.0005),
    RMSprop(learning_rate=0.0002),
    RMSprop(learning_rate = 0.0001),
]

name_list = [
    'Adam_LR_0.0001',
    'Adam_LR_0.0002',
    'Adam_LR_0.0005',
    'SGD_LR_0.005',
    'SGD_LR_0.05',
    'SGD_LR_0.01',
    'RMSprop_LR_0.0005',
    'RMSprop_LR_0.0002',
    'RMSprop_LR_0.0001'
]

tuned_Adam_LR_0_0001 = None
tuned_Adam_LR_0_0002 = None
tuned_Adam_LR_0_0005 = None
tuned_SGD_LR_0_005 = None
tuned_SGD_LR_0_05 = None
tuned_SGD_LR_0_01 = None
tuned_RMSprop_LR_0_0005 = None
tuned_RMSprop_LR_0_0002 = None
tuned_RMSprop_LR_0_0001 = None

model_list = [tuned_Adam_LR_0_0001, tuned_Adam_LR_0_0002, tuned_Adam_LR_0_0005, tuned_SGD_LR_0_005, tuned_SGD_LR_0_05, tuned_SGD_LR_0_01, tuned_RMSprop_LR_0_0005, tuned_RMSprop_LR_0_0002, tuned_RMSprop_LR_0_0001]

def model_hypertuner(optimizer, name, model, train_length = 30):
        clear_output(wait=True)
        print(f"Now attempting to tune {name}")
        model = hingeGAN(latent_dim=100)
        model.build(())
        model.load_weights('output/models/hinge_gan/weights/weights_199.h5')
        model.compile(
            d_optimizer=optimizer,
            g_optimizer=optimizer,
            loss_fn=BinaryCrossentropy(),
        )
        model_callback = CustomCallback(d_losses = model.d_loss_list, g_losses = model.g_loss_list, kl_div=model.kl_div_list, model = model, filepath = f"output/models/hypertune/{name}/")
        model.fit(X_train, epochs = train_length, callbacks = [model_callback])
        gc.collect()
        model_images, _ = model.generate_fake_samples(model, model.generator, n_samples = 100, latent_dim = 100)
        model_fid = calcFID(model_images, num_images = 1000)
        return model_fid, model_images
    
output_fid_list = []
output_images_list = []
for optimizer, name, model in zip(optimizer_list, name_list, model_list):
    model, history = model_hypertuner(optimizer, name, model, train_length=30)
    output_fid_list.append(model)
    output_images_list.append(history)
Now attempting to tune RMSprop_LR_0.0001
Epoch 1/30
1562/1562 [==============================] - 58s 36ms/step - d_loss: 0.8175 - g_loss: 1.3469 - kl_divergence: 0.3477
Epoch 2/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7706 - g_loss: 1.3849 - kl_divergence: 0.3518
Epoch 3/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7500 - g_loss: 1.4072 - kl_divergence: 0.3519
Epoch 4/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7411 - g_loss: 1.4317 - kl_divergence: 0.3532
Epoch 5/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7346 - g_loss: 1.4437 - kl_divergence: 0.3529
Epoch 6/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7279 - g_loss: 1.4502 - kl_divergence: 0.3515
Epoch 7/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7164 - g_loss: 1.4729 - kl_divergence: 0.3557
Epoch 8/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7203 - g_loss: 1.4696 - kl_divergence: 0.3653
Epoch 9/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7101 - g_loss: 1.4918 - kl_divergence: 0.3579
Epoch 10/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7082 - g_loss: 1.5021 - kl_divergence: 0.3561
Epoch 11/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7004 - g_loss: 1.5168 - kl_divergence: 0.3538
Epoch 12/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.7009 - g_loss: 1.5403 - kl_divergence: 0.3560
Epoch 13/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6940 - g_loss: 1.5389 - kl_divergence: 0.3687
Epoch 14/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6973 - g_loss: 1.5483 - kl_divergence: 0.3571
Epoch 15/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6922 - g_loss: 1.5561 - kl_divergence: 0.3571
Epoch 16/30
1562/1562 [==============================] - 57s 37ms/step - d_loss: 0.6926 - g_loss: 1.5567 - kl_divergence: 0.3570
Epoch 17/30
1562/1562 [==============================] - 62s 39ms/step - d_loss: 0.6930 - g_loss: 1.5601 - kl_divergence: 0.3521
Epoch 18/30
1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6888 - g_loss: 1.5766 - kl_divergence: 0.3517
Epoch 19/30
1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6896 - g_loss: 1.5757 - kl_divergence: 0.3523
Epoch 20/30
1562/1562 [==============================] - 60s 38ms/step - d_loss: 0.6815 - g_loss: 1.5937 - kl_divergence: 0.3575
Epoch 21/30
1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6826 - g_loss: 1.6008 - kl_divergence: 0.3604
Epoch 22/30
1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6823 - g_loss: 1.6140 - kl_divergence: 0.3562
Epoch 23/30
1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6810 - g_loss: 1.6086 - kl_divergence: 0.3591
Epoch 24/30
1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6675 - g_loss: 1.6305 - kl_divergence: 0.3548
Epoch 25/30
1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6785 - g_loss: 1.6191 - kl_divergence: 0.3641
Epoch 26/30
1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6818 - g_loss: 1.6346 - kl_divergence: 0.3557
Epoch 27/30
1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6763 - g_loss: 1.6464 - kl_divergence: 0.3532
Epoch 28/30
1562/1562 [==============================] - 61s 39ms/step - d_loss: 0.6748 - g_loss: 1.6534 - kl_divergence: 0.3547
Epoch 29/30
1562/1562 [==============================] - 59s 38ms/step - d_loss: 0.6718 - g_loss: 1.6539 - kl_divergence: 0.3544
Epoch 30/30
1562/1562 [==============================] - 60s 38ms/step - d_loss: 0.6735 - g_loss: 1.6584 - kl_divergence: 0.3579
32/32 [==============================] - 4s 88ms/step
32/32 [==============================] - 3s 87ms/step
In [ ]:
import pandas as pd

data = {'Name': name_list, 'Value': copy_arr}
df = pd.DataFrame(data)
plt.figure(figsize=(10,6)) # Optional: Set figure size
chart = sns.barplot(x='Value', y='Name', data=df)
chart.set_xlabel('FID Score', fontdict={'size': 12})
plt.tight_layout()

for index, value in enumerate(copy_arr):
    plt.text(value, index, round(value,3), color='black', va='center')

plt.title('FID Score based on optimizer') # Optional: Set title
plt.show()

From the above plot, we can see that there are 6 optimizer configurations which produce better results. We shall generate 100 images with each optimizer, and visually inspect the images to see which produces the best results.

In [ ]:
target_index = [1,2,3,6,7,8]

for index in target_index:
    print(f"Displaying images for {name_list[index]}")
    all_images = output_images_list[index]
    all_images = (all_images + 1) / 2.0  # Scale images to [0, 1]
    np.random.shuffle(all_images)
    batches = list(batch_images(all_images[:100], 100))

    for batch in batches:
        display_images_in_grid(batch, 10, title = f"Generated Images for {name_list[index]}\nFID = {copy_arr[index]:.3f}")
Displaying images for Adam_LR_0.0002
Displaying images for Adam_LR_0.0005
Displaying images for SGD_LR_0.005
Displaying images for RMSprop_LR_0.0005
Displaying images for RMSprop_LR_0.0002
Displaying images for RMSprop_LR_0.0001

From the above, we can see that all the images look similar. However, the images generated by the RMSprop optimizer with a learning rate of 0.001 seem to be most promising, as they have the most detail out of all, but are more mis-shapen. This shows that with further tuning, the model may be able to better learn the missing details and generate better images.

Data Augmentation

Now, we shall attempt to further improvie the model by performing data augmentation before feeding the images to themodel. Data Augmenatation helps to generate and expand our training data from existing samples by augmenting them using random transformations, such as flipping, cropping and rotating. Hence, exposes model to more aspects of data to generate more realistic images.

Sources:
https://www.tensorflow.org/tutorials/images/data_augmentation
https://keras.io/api/layers/preprocessing_layers/image_augmentation/

We are going to create a function to perform Data Augmentation techniques such as RandomFlip, RandomRotation and RandomCrop.

In [ ]:
from tensorflow.keras.layers.experimental import preprocessing
def augment_dataset(dataset):
    def augment(image, label):
        image = tf.image.random_flip_left_right(image)
        image = tf.image.random_flip_up_down(image)
        image = tf.image.rot90(image, tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32))
        image = tf.image.random_crop(image, size=[image.shape[0] - 4, image.shape[1] - 4, image.shape[2]])
        image = tf.image.resize(image, [32, 32])
        return image, label
    return dataset.map(augment)

X_train_augmented = augment_dataset(X_train.unbatch())

Next, we are going to apply the Data Augmentation on X_train_rescaled.

In [ ]:
plt.figure(figsize = (30,30))
for i, (image, _) in enumerate(X_train_augmented.take(100)):
    image = (image + 1) / 2.0
    plt.subplot(10, 10, i + 1)
    plt.imshow(image.numpy())
    plt.axis('off')
plt.show()

We are now going to combine the original train dataset, X_train_rescaled, & the data augmentated train dataset, X_train_dataAug.

After doing so, we will have double the amount of train images at 99968 (32 images lost due to half-batching).

In [ ]:
X_train_augmented = X_train.unbatch().concatenate(X_train_augmented)
In [ ]:
print(f"Length of initial dataset: {X_train.unbatch().cardinality().numpy()} images")
print(f"Length of augmented dataset: {X_train_augmented.cardinality().numpy()} images")    
Length of initial dataset: 49984 images
Length of augmented dataset: 99968 images

Final Datasets for GAN Model Improvement

Finally, we are going to feed the augmented data to the model for final training. We shall train two models for 30 epochs each, one with augmneted and one with un-augmented data to see the effects of the augmented data on the model.

In [ ]:
optimized_hinge_gan_augmented = hingeGAN(latent_dim=100)
optimized_hinge_gan_augmented.build(())
optimized_hinge_gan_augmented.load_weights('output/models/hypertune/RMSprop_LR_0.0001/weights//weights_29.h5')
optimized_hinge_gan_augmented.compile(
    d_optimizer=RMSprop(learning_rate=0.0001),
    g_optimizer=RMSprop(learning_rate=0.0001),
    loss_fn=Hinge(),
)
hinge_gan_callback = CustomCallback(d_losses = optimized_hinge_gan_augmented.d_loss_list, g_losses = optimized_hinge_gan_augmented.g_loss_list, kl_div=optimized_hinge_gan_augmented.kl_div_list, model = optimized_hinge_gan_augmented, filepath = "output/models/optimized_hinge_gan_augmented/")
optimized_hinge_gan_augmented.fit(X_train_augmented.batch(32, drop_remainder=True), epochs = 30, callbacks = [hinge_gan_callback])
Epoch 1/30
3124/3124 [==============================] - 113s 35ms/step - d_loss: 0.9349 - g_loss: 1.4121 - kl_divergence: 0.2578
Epoch 2/30
3124/3124 [==============================] - 114s 37ms/step - d_loss: 1.0235 - g_loss: 1.3359 - kl_divergence: 0.2421
Epoch 3/30
3124/3124 [==============================] - 111s 36ms/step - d_loss: 1.0222 - g_loss: 1.3069 - kl_divergence: 0.2145
Epoch 4/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0289 - g_loss: 1.2999 - kl_divergence: 0.2144
Epoch 5/30
3124/3124 [==============================] - 110s 35ms/step - d_loss: 1.0299 - g_loss: 1.2838 - kl_divergence: 0.2193
Epoch 6/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0394 - g_loss: 1.2809 - kl_divergence: 0.2169
Epoch 7/30
3124/3124 [==============================] - 112s 36ms/step - d_loss: 1.0327 - g_loss: 1.2758 - kl_divergence: 0.2150
Epoch 8/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0340 - g_loss: 1.2634 - kl_divergence: 0.2107
Epoch 9/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0312 - g_loss: 1.2546 - kl_divergence: 0.2120
Epoch 10/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0337 - g_loss: 1.2696 - kl_divergence: 0.2089
Epoch 11/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0454 - g_loss: 1.2416 - kl_divergence: 0.2100
Epoch 12/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0356 - g_loss: 1.2528 - kl_divergence: 0.2134
Epoch 13/30
3124/3124 [==============================] - 112s 36ms/step - d_loss: 1.0337 - g_loss: 1.2572 - kl_divergence: 0.2063
Epoch 14/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0338 - g_loss: 1.2512 - kl_divergence: 0.2067
Epoch 15/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0362 - g_loss: 1.2508 - kl_divergence: 0.2093
Epoch 16/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0273 - g_loss: 1.2512 - kl_divergence: 0.2046
Epoch 17/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0267 - g_loss: 1.2469 - kl_divergence: 0.2124
Epoch 18/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0314 - g_loss: 1.2428 - kl_divergence: 0.2057
Epoch 19/30
3124/3124 [==============================] - 110s 35ms/step - d_loss: 1.0306 - g_loss: 1.2419 - kl_divergence: 0.2076
Epoch 20/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0357 - g_loss: 1.2344 - kl_divergence: 0.2067
Epoch 21/30
3124/3124 [==============================] - 112s 36ms/step - d_loss: 1.0303 - g_loss: 1.2390 - kl_divergence: 0.2010
Epoch 22/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0314 - g_loss: 1.2403 - kl_divergence: 0.2154
Epoch 23/30
3124/3124 [==============================] - 110s 35ms/step - d_loss: 1.0332 - g_loss: 1.2493 - kl_divergence: 0.2138
Epoch 24/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0334 - g_loss: 1.2367 - kl_divergence: 0.2097
Epoch 25/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0321 - g_loss: 1.2443 - kl_divergence: 0.2077
Epoch 26/30
3124/3124 [==============================] - 110s 35ms/step - d_loss: 1.0290 - g_loss: 1.2360 - kl_divergence: 0.2079
Epoch 27/30
3124/3124 [==============================] - 110s 35ms/step - d_loss: 1.0315 - g_loss: 1.2444 - kl_divergence: 0.2063
Epoch 28/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0400 - g_loss: 1.2315 - kl_divergence: 0.2025
Epoch 29/30
3124/3124 [==============================] - 111s 35ms/step - d_loss: 1.0379 - g_loss: 1.2260 - kl_divergence: 0.2179
Epoch 30/30
3124/3124 [==============================] - 110s 35ms/step - d_loss: 1.0300 - g_loss: 1.2352 - kl_divergence: 0.2022
In [ ]:
model_images, _ = optimized_hinge_gan_augmented.generate_fake_samples(optimized_hinge_gan_augmented, optimized_hinge_gan_augmented.generator, n_samples = 100, latent_dim = 100)
model_fid = calcFID(model_images, num_images = 1000)
print(f"FID for this model is {model_fid}")
32/32 [==============================] - 4s 80ms/step
32/32 [==============================] - 3s 80ms/step
FID for this model is 92.08255961757423
In [ ]:
optimized_hinge_gan = hingeGAN(latent_dim=100)
optimized_hinge_gan.build(())
optimized_hinge_gan.load_weights('output/models/hypertune/RMSprop_LR_0.0001/weights//weights_29.h5')
optimized_hinge_gan.compile(
    d_optimizer=RMSprop(learning_rate=0.0001),
    g_optimizer=RMSprop(learning_rate=0.0001),
    loss_fn=Hinge(),
)
hinge_gan_callback = CustomCallback(d_losses = optimized_hinge_gan.d_loss_list, g_losses = optimized_hinge_gan.g_loss_list, kl_div=optimized_hinge_gan.kl_div_list, model = optimized_hinge_gan, filepath = "output/models/optimized_hinge_gan/")
optimized_hinge_gan.fit(X_train, epochs = 30, callbacks = [hinge_gan_callback])
Epoch 1/30
1562/1562 [==============================] - 58s 36ms/step - d_loss: 0.6968 - g_loss: 1.6341 - kl_divergence: 0.3676
Epoch 2/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6986 - g_loss: 1.6358 - kl_divergence: 0.3562
Epoch 3/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6957 - g_loss: 1.6435 - kl_divergence: 0.3568
Epoch 4/30
1562/1562 [==============================] - 57s 37ms/step - d_loss: 0.6983 - g_loss: 1.6359 - kl_divergence: 0.3534
Epoch 5/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6955 - g_loss: 1.6453 - kl_divergence: 0.3548
Epoch 6/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6957 - g_loss: 1.6457 - kl_divergence: 0.3520
Epoch 7/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6908 - g_loss: 1.6564 - kl_divergence: 0.3573
Epoch 8/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6943 - g_loss: 1.6470 - kl_divergence: 0.3551
Epoch 9/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6864 - g_loss: 1.6681 - kl_divergence: 0.3562
Epoch 10/30
1562/1562 [==============================] - 57s 37ms/step - d_loss: 0.6934 - g_loss: 1.6683 - kl_divergence: 0.3561
Epoch 11/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6923 - g_loss: 1.6623 - kl_divergence: 0.3554
Epoch 12/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6931 - g_loss: 1.6750 - kl_divergence: 0.3542
Epoch 13/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6851 - g_loss: 1.6814 - kl_divergence: 0.3563
Epoch 14/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6882 - g_loss: 1.6800 - kl_divergence: 0.3591
Epoch 15/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6868 - g_loss: 1.6793 - kl_divergence: 0.3587
Epoch 16/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6901 - g_loss: 1.6842 - kl_divergence: 0.3555
Epoch 17/30
1562/1562 [==============================] - 58s 37ms/step - d_loss: 0.6888 - g_loss: 1.6784 - kl_divergence: 0.3547
Epoch 18/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6900 - g_loss: 1.6878 - kl_divergence: 0.3528
Epoch 19/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6867 - g_loss: 1.6957 - kl_divergence: 0.3564
Epoch 20/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6809 - g_loss: 1.6891 - kl_divergence: 0.3566
Epoch 21/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6833 - g_loss: 1.7030 - kl_divergence: 0.3561
Epoch 22/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6760 - g_loss: 1.7161 - kl_divergence: 0.3544
Epoch 23/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6781 - g_loss: 1.7147 - kl_divergence: 0.3546
Epoch 24/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6727 - g_loss: 1.7177 - kl_divergence: 0.3550
Epoch 25/30
1562/1562 [==============================] - 58s 37ms/step - d_loss: 0.6823 - g_loss: 1.7063 - kl_divergence: 0.3533
Epoch 26/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6836 - g_loss: 1.7184 - kl_divergence: 0.3557
Epoch 27/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6811 - g_loss: 1.7192 - kl_divergence: 0.3527
Epoch 28/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6753 - g_loss: 1.7416 - kl_divergence: 0.3564
Epoch 29/30
1562/1562 [==============================] - 57s 36ms/step - d_loss: 0.6737 - g_loss: 1.7314 - kl_divergence: 0.3550
Epoch 30/30
1562/1562 [==============================] - 56s 36ms/step - d_loss: 0.6846 - g_loss: 1.7264 - kl_divergence: 0.3556
Out[ ]:
<keras.callbacks.History at 0x1a4367166a0>
In [ ]:
model_images, _ = optimized_hinge_gan.generate_fake_samples(optimized_hinge_gan, optimized_hinge_gan.generator, n_samples = 100, latent_dim = 100)
model_fid = calcFID(model_images, num_images = 1000)
print(f"FID for this model is {model_fid}")
32/32 [==============================] - 4s 80ms/step
32/32 [==============================] - 3s 80ms/step
FID for this model is 69.90489118078709
Augmented Results
Unaugmented Results

Comparing the models in terms of metrics and visualization, we can see that the unaugmented model still performs better than the augmneted model. Not only does the unaugmented model hold detail in it's images better than that of the augmented one, it also performs better metric-wise, with a FID of 69 compared to the augmented model's score of 92. Hence, we shall seelct the unaugmented model as our final model, and generate the images required.


Final Image Generation

The images will also be saved in the "/generated" directory.

The final model's weights will also be saved as best_weights.h5, and can be found in the same directory as this file.

In [ ]:
model_images, _ = optimized_hinge_gan.generate_fake_samples(optimized_hinge_gan, optimized_hinge_gan.generator, n_samples = 100, latent_dim = 100)
model_images = (model_images + 1) / 2.0  # Scale images to [0, 1]
batches = list(batch_images(model_images, 100))
for i, batch in enumerate(batches):
    display_images_in_grid(batch, 10, title = f"Generated {class_names[i]} Images for Optimized Hinge GAN\n")

Conclusion

To conclude, this task was a fun and intruiging one for both of us, and allowed us deeper insight into the considerations and limitations that have to be taken into account during the development of such AI tools. We have also learned deepened our understanding, furthered our learning, and solidified our foundation in the field of deep learning, especially on the topic of Generative Adversarial Networks. Despite this assignment being a interesting, there are some ideas and things that we would have loved to try, such as SA-GAN models, or using more ways to stablilize/improve models such as gaussian weights, or even upscaling the data to generate images with higher fidelity.